From 1014bf426b8b1de4cacca5848302b215867e7ad6 Mon Sep 17 00:00:00 2001 From: Jan Verbeek Date: Thu, 20 Mar 2025 20:33:28 +0100 Subject: [PATCH 1/8] shuf: Move NonrepeatingIterator to own module --- src/uu/shuf/src/nonrepeating_iterator.rs | 182 +++++++++++++++++++++++ src/uu/shuf/src/shuf.rs | 175 +--------------------- 2 files changed, 185 insertions(+), 172 deletions(-) create mode 100644 src/uu/shuf/src/nonrepeating_iterator.rs diff --git a/src/uu/shuf/src/nonrepeating_iterator.rs b/src/uu/shuf/src/nonrepeating_iterator.rs new file mode 100644 index 00000000000..dfefd117863 --- /dev/null +++ b/src/uu/shuf/src/nonrepeating_iterator.rs @@ -0,0 +1,182 @@ +// spell-checker:ignore nonrepeating + +use std::{collections::HashSet, ops::RangeInclusive}; + +use rand::{Rng, seq::SliceRandom}; + +use crate::WrappedRng; + +enum NumberSet { + AlreadyListed(HashSet), + Remaining(Vec), +} + +pub(crate) struct NonrepeatingIterator<'a> { + range: RangeInclusive, + rng: &'a mut WrappedRng, + remaining_count: usize, + buf: NumberSet, +} + +impl<'a> NonrepeatingIterator<'a> { + pub(crate) fn new( + range: RangeInclusive, + rng: &'a mut WrappedRng, + amount: usize, + ) -> Self { + let capped_amount = if range.start() > range.end() { + 0 + } else if range == (0..=usize::MAX) { + amount + } else { + amount.min(range.end() - range.start() + 1) + }; + NonrepeatingIterator { + range, + rng, + remaining_count: capped_amount, + buf: NumberSet::AlreadyListed(HashSet::default()), + } + } + + fn produce(&mut self) -> usize { + debug_assert!(self.range.start() <= self.range.end()); + match &mut self.buf { + NumberSet::AlreadyListed(already_listed) => { + let chosen = loop { + let guess = self.rng.random_range(self.range.clone()); + let newly_inserted = already_listed.insert(guess); + if newly_inserted { + break guess; + } + }; + // Once a significant fraction of the interval has already been enumerated, + // the number of attempts to find a number that hasn't been chosen yet increases. 
+ // Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values". + let range_size = (self.range.end() - self.range.start()).saturating_add(1); + if number_set_should_list_remaining(already_listed.len(), range_size) { + let mut remaining = self + .range + .clone() + .filter(|n| !already_listed.contains(n)) + .collect::>(); + assert!(remaining.len() >= self.remaining_count); + remaining.partial_shuffle(&mut self.rng, self.remaining_count); + remaining.truncate(self.remaining_count); + self.buf = NumberSet::Remaining(remaining); + } + chosen + } + NumberSet::Remaining(remaining_numbers) => { + debug_assert!(!remaining_numbers.is_empty()); + // We only enter produce() when there is at least one actual element remaining, so popping must always return an element. + remaining_numbers.pop().unwrap() + } + } + } +} + +impl Iterator for NonrepeatingIterator<'_> { + type Item = usize; + + fn next(&mut self) -> Option { + if self.range.is_empty() || self.remaining_count == 0 { + return None; + } + self.remaining_count -= 1; + Some(self.produce()) + } +} + +// This could be a method, but it is much easier to test as a stand-alone function. +fn number_set_should_list_remaining(listed_count: usize, range_size: usize) -> bool { + // Arbitrarily determine the switchover point to be around 25%. This is because: + // - HashSet has a large space overhead for the hash table load factor. + // - This means that somewhere between 25-40%, the memory required for a "positive" HashSet and a "negative" Vec should be the same. + // - HashSet has a small but non-negligible overhead for each lookup, so we have a slight preference for Vec anyway. + // - At 25%, on average 1.33 attempts are needed to find a number that hasn't been taken yet. 
+ // - Finally, "24%" is computationally the simplest: + listed_count >= range_size / 4 +} + +#[cfg(test)] +// Since the computed value is a bool, it is more readable to write the expected value out: +#[allow(clippy::bool_assert_comparison)] +mod test_number_set_decision { + use super::number_set_should_list_remaining; + + #[test] + fn test_stay_positive_large_remaining_first() { + assert_eq!(false, number_set_should_list_remaining(0, usize::MAX)); + } + + #[test] + fn test_stay_positive_large_remaining_second() { + assert_eq!(false, number_set_should_list_remaining(1, usize::MAX)); + } + + #[test] + fn test_stay_positive_large_remaining_tenth() { + assert_eq!(false, number_set_should_list_remaining(9, usize::MAX)); + } + + #[test] + fn test_stay_positive_smallish_range_first() { + assert_eq!(false, number_set_should_list_remaining(0, 12345)); + } + + #[test] + fn test_stay_positive_smallish_range_second() { + assert_eq!(false, number_set_should_list_remaining(1, 12345)); + } + + #[test] + fn test_stay_positive_smallish_range_tenth() { + assert_eq!(false, number_set_should_list_remaining(9, 12345)); + } + + #[test] + fn test_stay_positive_small_range_not_too_early() { + assert_eq!(false, number_set_should_list_remaining(1, 10)); + } + + // Don't want to test close to the border, in case we decide to change the threshold. 
+ // However, at 50% coverage, we absolutely should switch: + #[test] + fn test_switch_half() { + assert_eq!(true, number_set_should_list_remaining(1234, 2468)); + } + + // Ensure that the decision is monotonous: + #[test] + fn test_switch_late1() { + assert_eq!(true, number_set_should_list_remaining(12340, 12345)); + } + + #[test] + fn test_switch_late2() { + assert_eq!(true, number_set_should_list_remaining(12344, 12345)); + } + + // Ensure that we are overflow-free: + #[test] + fn test_no_crash_exceed_max_size1() { + assert_eq!(false, number_set_should_list_remaining(12345, usize::MAX)); + } + + #[test] + fn test_no_crash_exceed_max_size2() { + assert_eq!( + true, + number_set_should_list_remaining(usize::MAX - 1, usize::MAX) + ); + } + + #[test] + fn test_no_crash_exceed_max_size3() { + assert_eq!( + true, + number_set_should_list_remaining(usize::MAX, usize::MAX) + ); + } +} diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs index 4fd5ca85a0f..47e9cd1320b 100644 --- a/src/uu/shuf/src/shuf.rs +++ b/src/uu/shuf/src/shuf.rs @@ -10,7 +10,6 @@ use clap::{Arg, ArgAction, Command}; use rand::prelude::SliceRandom; use rand::seq::IndexedRandom; use rand::{Rng, RngCore}; -use std::collections::HashSet; use std::ffi::{OsStr, OsString}; use std::fs::File; use std::io::{BufWriter, Error, Read, Write, stdin, stdout}; @@ -22,8 +21,11 @@ use uucore::error::{FromIo, UResult, USimpleError, UUsageError}; use uucore::format_usage; use uucore::translate; +mod nonrepeating_iterator; mod rand_read_adapter; +use nonrepeating_iterator::NonrepeatingIterator; + enum Mode { Default(PathBuf), Echo(Vec), @@ -315,95 +317,6 @@ impl Shufable for RangeInclusive { } } -enum NumberSet { - AlreadyListed(HashSet), - Remaining(Vec), -} - -struct NonrepeatingIterator<'a> { - range: RangeInclusive, - rng: &'a mut WrappedRng, - remaining_count: usize, - buf: NumberSet, -} - -impl<'a> NonrepeatingIterator<'a> { - fn new(range: RangeInclusive, rng: &'a mut WrappedRng, amount: usize) -> Self 
{ - let capped_amount = if range.start() > range.end() { - 0 - } else if range == (0..=usize::MAX) { - amount - } else { - amount.min(range.end() - range.start() + 1) - }; - NonrepeatingIterator { - range, - rng, - remaining_count: capped_amount, - buf: NumberSet::AlreadyListed(HashSet::default()), - } - } - - fn produce(&mut self) -> usize { - debug_assert!(self.range.start() <= self.range.end()); - match &mut self.buf { - NumberSet::AlreadyListed(already_listed) => { - let chosen = loop { - let guess = self.rng.random_range(self.range.clone()); - let newly_inserted = already_listed.insert(guess); - if newly_inserted { - break guess; - } - }; - // Once a significant fraction of the interval has already been enumerated, - // the number of attempts to find a number that hasn't been chosen yet increases. - // Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values". - let range_size = (self.range.end() - self.range.start()).saturating_add(1); - if number_set_should_list_remaining(already_listed.len(), range_size) { - let mut remaining = self - .range - .clone() - .filter(|n| !already_listed.contains(n)) - .collect::>(); - assert!(remaining.len() >= self.remaining_count); - remaining.partial_shuffle(&mut self.rng, self.remaining_count); - remaining.truncate(self.remaining_count); - self.buf = NumberSet::Remaining(remaining); - } - chosen - } - NumberSet::Remaining(remaining_numbers) => { - debug_assert!(!remaining_numbers.is_empty()); - // We only enter produce() when there is at least one actual element remaining, so popping must always return an element. 
- remaining_numbers.pop().unwrap() - } - } - } -} - -impl Iterator for NonrepeatingIterator<'_> { - type Item = usize; - - fn next(&mut self) -> Option { - if self.range.is_empty() || self.remaining_count == 0 { - return None; - } - self.remaining_count -= 1; - Some(self.produce()) - } -} - -// This could be a method, but it is much easier to test as a stand-alone function. -fn number_set_should_list_remaining(listed_count: usize, range_size: usize) -> bool { - // Arbitrarily determine the switchover point to be around 25%. This is because: - // - HashSet has a large space overhead for the hash table load factor. - // - This means that somewhere between 25-40%, the memory required for a "positive" HashSet and a "negative" Vec should be the same. - // - HashSet has a small but non-negligible overhead for each lookup, so we have a slight preference for Vec anyway. - // - At 25%, on average 1.33 attempts are needed to find a number that hasn't been taken yet. - // - Finally, "24%" is computationally the simplest: - listed_count >= range_size / 4 -} - trait Writable { fn write_all_to(&self, output: &mut impl OsWrite) -> Result<(), Error>; } @@ -543,85 +456,3 @@ mod test_split_seps { assert_eq!(split_seps(b"a\nb\nc", b'\n'), &[b"a", b"b", b"c"]); } } - -#[cfg(test)] -// Since the computed value is a bool, it is more readable to write the expected value out: -#[allow(clippy::bool_assert_comparison)] -mod test_number_set_decision { - use super::number_set_should_list_remaining; - - #[test] - fn test_stay_positive_large_remaining_first() { - assert_eq!(false, number_set_should_list_remaining(0, usize::MAX)); - } - - #[test] - fn test_stay_positive_large_remaining_second() { - assert_eq!(false, number_set_should_list_remaining(1, usize::MAX)); - } - - #[test] - fn test_stay_positive_large_remaining_tenth() { - assert_eq!(false, number_set_should_list_remaining(9, usize::MAX)); - } - - #[test] - fn test_stay_positive_smallish_range_first() { - assert_eq!(false, 
number_set_should_list_remaining(0, 12345)); - } - - #[test] - fn test_stay_positive_smallish_range_second() { - assert_eq!(false, number_set_should_list_remaining(1, 12345)); - } - - #[test] - fn test_stay_positive_smallish_range_tenth() { - assert_eq!(false, number_set_should_list_remaining(9, 12345)); - } - - #[test] - fn test_stay_positive_small_range_not_too_early() { - assert_eq!(false, number_set_should_list_remaining(1, 10)); - } - - // Don't want to test close to the border, in case we decide to change the threshold. - // However, at 50% coverage, we absolutely should switch: - #[test] - fn test_switch_half() { - assert_eq!(true, number_set_should_list_remaining(1234, 2468)); - } - - // Ensure that the decision is monotonous: - #[test] - fn test_switch_late1() { - assert_eq!(true, number_set_should_list_remaining(12340, 12345)); - } - - #[test] - fn test_switch_late2() { - assert_eq!(true, number_set_should_list_remaining(12344, 12345)); - } - - // Ensure that we are overflow-free: - #[test] - fn test_no_crash_exceed_max_size1() { - assert_eq!(false, number_set_should_list_remaining(12345, usize::MAX)); - } - - #[test] - fn test_no_crash_exceed_max_size2() { - assert_eq!( - true, - number_set_should_list_remaining(usize::MAX - 1, usize::MAX) - ); - } - - #[test] - fn test_no_crash_exceed_max_size3() { - assert_eq!( - true, - number_set_should_list_remaining(usize::MAX, usize::MAX) - ); - } -} From e6be570d9341028d0cbf4ad3030c9ddad34f658d Mon Sep 17 00:00:00 2001 From: Jan Verbeek Date: Sun, 23 Mar 2025 17:57:55 +0100 Subject: [PATCH 2/8] shuf: correctness: Flush output after writing This is important since the output is buffered and errors may end up ignored otherwise. `shuf -e a b c > /dev/full` now errors while it didn't before. 
--- src/uu/shuf/src/shuf.rs | 1 + tests/by-util/test_shuf.rs | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs index 47e9cd1320b..98e035b6346 100644 --- a/src/uu/shuf/src/shuf.rs +++ b/src/uu/shuf/src/shuf.rs @@ -386,6 +386,7 @@ fn shuf_exec( output.write_all(&[opts.sep]).map_err_context(ctx)?; } } + output.flush().map_err_context(ctx)?; Ok(()) } diff --git a/tests/by-util/test_shuf.rs b/tests/by-util/test_shuf.rs index 4d3f841ace9..83f02f04918 100644 --- a/tests/by-util/test_shuf.rs +++ b/tests/by-util/test_shuf.rs @@ -847,3 +847,15 @@ fn test_range_repeat_empty_minus_one() { .no_stdout() .stderr_contains("invalid value '5-3' for '--input-range ': start exceeds end\n"); } + +// This test fails if we forget to flush the `BufWriter`. +#[test] +#[cfg(target_os = "linux")] +fn write_errors_are_reported() { + new_ucmd!() + .arg("-i1-10") + .arg("-o/dev/full") + .fails() + .no_stdout() + .stderr_is("shuf: write failed: No space left on device\n"); +} From ffb041e3f3c67bce1576d9bad8b67f5efa40ae3a Mon Sep 17 00:00:00 2001 From: Jan Verbeek Date: Sun, 23 Mar 2025 18:08:42 +0100 Subject: [PATCH 3/8] shuf: perf: Bump output buffer to 64KB --- src/uu/shuf/src/shuf.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs index 98e035b6346..c0765e55339 100644 --- a/src/uu/shuf/src/shuf.rs +++ b/src/uu/shuf/src/shuf.rs @@ -32,6 +32,8 @@ enum Mode { InputRange(RangeInclusive), } +const BUF_SIZE: usize = 64 * 1024; + struct Options { head_count: usize, output: Option, @@ -100,15 +102,18 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { }, }; - let mut output = BufWriter::new(match options.output { - None => Box::new(stdout()) as Box, - Some(ref s) => { - let file = File::create(s).map_err_context( - || translate!("shuf-error-failed-to-open-for-writing", "file" => s.quote()), - )?; - Box::new(file) as Box - } - }); + let 
mut output = BufWriter::with_capacity( + BUF_SIZE, + match options.output { + None => Box::new(stdout()) as Box, + Some(ref s) => { + let file = File::create(s).map_err_context( + || translate!("shuf-error-failed-to-open-for-writing", "file" => s.quote()), + )?; + Box::new(file) as Box + } + }, + ); if options.head_count == 0 { // In this case we do want to touch the output file but we can quit immediately. From a6a501fe537cc42a9e357e40a0052b00de069959 Mon Sep 17 00:00:00 2001 From: Jan Verbeek Date: Mon, 24 Mar 2025 19:01:34 +0100 Subject: [PATCH 4/8] shuf: correctness: Do not use panics to report --random-source read errors --- src/uu/shuf/locales/en-US.ftl | 1 + src/uu/shuf/src/rand_read_adapter.rs | 47 ++++++++++++---------------- src/uu/shuf/src/shuf.rs | 24 +++++++++++++- 3 files changed, 44 insertions(+), 28 deletions(-) diff --git a/src/uu/shuf/locales/en-US.ftl b/src/uu/shuf/locales/en-US.ftl index 24876e6a37f..913ddaa1d01 100644 --- a/src/uu/shuf/locales/en-US.ftl +++ b/src/uu/shuf/locales/en-US.ftl @@ -19,6 +19,7 @@ shuf-error-unexpected-argument = unexpected argument { $arg } found shuf-error-failed-to-open-for-writing = failed to open { $file } for writing shuf-error-failed-to-open-random-source = failed to open random source { $file } shuf-error-read-error = read error +shuf-error-read-random-bytes = reading random bytes failed shuf-error-no-lines-to-repeat = no lines to repeat shuf-error-start-exceeds-end = start exceeds end shuf-error-missing-dash = missing '-' diff --git a/src/uu/shuf/src/rand_read_adapter.rs b/src/uu/shuf/src/rand_read_adapter.rs index 3f504c03d2b..84c7e8bf218 100644 --- a/src/uu/shuf/src/rand_read_adapter.rs +++ b/src/uu/shuf/src/rand_read_adapter.rs @@ -13,8 +13,9 @@ //! A wrapper around any Read to treat it as an RNG. 
-use std::fmt; -use std::io::Read; +use std::cell::Cell; +use std::io::{Error, Read}; +use std::rc::Rc; use rand_core::{RngCore, impls}; @@ -30,27 +31,33 @@ use rand_core::{RngCore, impls}; /// /// `ReadRng` uses [`std::io::Read::read_exact`], which retries on interrupts. /// All other errors from the underlying reader, including when it does not -/// have enough data, will only be reported through `try_fill_bytes`. -/// The other [`RngCore`] methods will panic in case of an error. +/// have enough data, will be reported via the public error field (which can +/// be cloned in advance, as it uses [`Rc`]). This field must be checked for +/// errors after every operation. /// /// [`OsRng`]: rand::rngs::OsRng -#[derive(Debug)] pub struct ReadRng { reader: R, + pub error: ErrorCell, } +pub type ErrorCell = Rc>>; + impl ReadRng { /// Create a new `ReadRng` from a `Read`. pub fn new(r: R) -> Self { - Self { reader: r } + Self { + reader: r, + error: Rc::default(), + } } - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), ReadError> { + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { if dest.is_empty() { return Ok(()); } // Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`. 
- self.reader.read_exact(dest).map_err(ReadError) + self.reader.read_exact(dest) } } @@ -64,25 +71,11 @@ impl RngCore for ReadRng { } fn fill_bytes(&mut self, dest: &mut [u8]) { - self.try_fill_bytes(dest).unwrap_or_else(|err| { - panic!("reading random bytes from Read implementation failed; error: {err}"); - }); - } -} - -/// `ReadRng` error type -#[derive(Debug)] -pub struct ReadError(std::io::Error); - -impl fmt::Display for ReadError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ReadError: {}", self.0) - } -} - -impl std::error::Error for ReadError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - Some(&self.0) + if let Err(err) = self.try_fill_bytes(dest) { + // Failed to deliver random data, so the caller must check the error + // cell before using the result. + self.error.set(Some(err)); + } } } diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs index c0765e55339..1cc44428353 100644 --- a/src/uu/shuf/src/shuf.rs +++ b/src/uu/shuf/src/shuf.rs @@ -370,7 +370,7 @@ fn shuf_exec( output: &mut BufWriter>, ) -> UResult<()> { let ctx = || translate!("shuf-error-write-failed"); - + let error_cell = rng.get_error_cell(); if opts.repeat { if input.is_empty() { return Err(USimpleError::new( @@ -380,12 +380,15 @@ fn shuf_exec( } for _ in 0..opts.head_count { let r = input.choose(rng); + WrappedRng::check_error(error_cell.as_ref())?; r.write_all_to(output).map_err_context(ctx)?; output.write_all(&[opts.sep]).map_err_context(ctx)?; } } else { let shuffled = input.partial_shuffle(rng, opts.head_count); + WrappedRng::check_error(error_cell.as_ref())?; + for r in shuffled { r.write_all_to(output).map_err_context(ctx)?; output.write_all(&[opts.sep]).map_err_context(ctx)?; @@ -415,6 +418,25 @@ enum WrappedRng { RngDefault(rand::rngs::ThreadRng), } +impl WrappedRng { + fn get_error_cell(&self) -> Option { + if let Self::RngFile(adapter) = self { + Some(adapter.error.clone()) + } else { + None + } + } + + fn 
check_error(error_cell: Option<&rand_read_adapter::ErrorCell>) -> UResult<()> { + if let Some(cell) = error_cell { + if let Some(err) = cell.take() { + return Err(err.map_err_context(|| translate!("shuf-error-read-random-bytes"))); + } + } + Ok(()) + } +} + impl RngCore for WrappedRng { fn next_u32(&mut self) -> u32 { match self { From 8e4ada0796f2a87eec5e0af0720324848f56e183 Mon Sep 17 00:00:00 2001 From: Jan Verbeek Date: Wed, 26 Mar 2025 10:34:54 +0100 Subject: [PATCH 5/8] shuf: perf: Use itoa for integer formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This gives a 1.8× speedup over a stdlib formatted write for `shuf -r -n1000000 -i1-1024`. The original version of this commit replaced a formatted write, but before it got merged main received optimized manual formatting from another PR. The speedup of itoa over the manual write is around 1.1×, much less dramatic. --- .../workspace.wordlist.txt | 1 + Cargo.lock | 1 + Cargo.toml | 1 + src/uu/shuf/Cargo.toml | 1 + src/uu/shuf/src/shuf.rs | 24 ++++--------------- 5 files changed, 8 insertions(+), 20 deletions(-) diff --git a/.vscode/cspell.dictionaries/workspace.wordlist.txt b/.vscode/cspell.dictionaries/workspace.wordlist.txt index 28c468d4f9c..30d2bd3e04b 100644 --- a/.vscode/cspell.dictionaries/workspace.wordlist.txt +++ b/.vscode/cspell.dictionaries/workspace.wordlist.txt @@ -38,6 +38,7 @@ getrandom globset indicatif itertools +itoa iuse langid lscolors diff --git a/Cargo.lock b/Cargo.lock index e0da285d3f2..d0126ab3a4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3822,6 +3822,7 @@ dependencies = [ "clap", "codspeed-divan-compat", "fluent", + "itoa", "rand 0.9.2", "rand_core 0.9.5", "tempfile", diff --git a/Cargo.toml b/Cargo.toml index 7d3ee246268..121b155381f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -336,6 +336,7 @@ icu_locale = "2.0.0" icu_provider = "2.0.0" indicatif = "0.18.0" itertools = "0.14.0" +itoa = "1.0.15" jiff = "0.2.18" libc = "0.2.172" 
linux-raw-sys = "0.12" diff --git a/src/uu/shuf/Cargo.toml b/src/uu/shuf/Cargo.toml index b67b1d80811..ee4e217d0f4 100644 --- a/src/uu/shuf/Cargo.toml +++ b/src/uu/shuf/Cargo.toml @@ -19,6 +19,7 @@ path = "src/shuf.rs" [dependencies] clap = { workspace = true } +itoa = { workspace = true } rand = { workspace = true } rand_core = { workspace = true } uucore = { workspace = true } diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs index 1cc44428353..cc04a3068d6 100644 --- a/src/uu/shuf/src/shuf.rs +++ b/src/uu/shuf/src/shuf.rs @@ -340,26 +340,10 @@ impl Writable for &OsStr { impl Writable for usize { fn write_all_to(&self, output: &mut impl OsWrite) -> Result<(), Error> { - let mut n = *self; - - // Handle the zero case explicitly - if n == 0 { - return output.write_all(b"0"); - } - - // Maximum number of digits for u64 is 20 (18446744073709551615) - let mut buf = [0u8; 20]; - let mut i = 20; - - // Write digits from right to left - while n > 0 { - i -= 1; - buf[i] = b'0' + (n % 10) as u8; - n /= 10; - } - - // Write the relevant part of the buffer to output - output.write_all(&buf[i..]) + // The itoa crate is surprisingly much more efficient than a formatted write. + // It speeds up `shuf -r -n1000000 -i1-1024` by 1.8×. + let mut buf = itoa::Buffer::new(); + output.write_all(buf.format(*self).as_bytes()) } } From 6560a6985e7410ad1a62a0edf22d13538cc15cc9 Mon Sep 17 00:00:00 2001 From: Jan Verbeek Date: Tue, 25 Mar 2025 19:19:26 +0100 Subject: [PATCH 6/8] shuf: correctness: Make --random-source compatible with GNU shuf When the --random-source option is used uutils shuf now gives identical output to GNU shuf in many (but not all) cases. This is helpful to users who use it to get deterministic output, e.g. by combining it with `openssl` as suggested in the GNU info pages. I reverse engineered the algorithm from GNU shuf's output. There may be bugs. 
Other modes of shuffling still use `rand`'s `ThreadRng`, though they now sample a uniform distribution directly without going through the slice helper trait. Additionally, switch from `usize` to `u64` for `--input-range` and `--head-count`. This way the same range of numbers can be generated on 32-bit platforms as on 64-bit platforms. --- .../cspell.dictionaries/jargon.wordlist.txt | 2 + src/uu/shuf/locales/en-US.ftl | 1 + src/uu/shuf/src/compat_random_source.rs | 107 ++++++++++++ src/uu/shuf/src/nonrepeating_iterator.rs | 57 +++---- src/uu/shuf/src/rand_read_adapter.rs | 135 --------------- src/uu/shuf/src/shuf.rs | 138 +++++++-------- tests/by-util/test_shuf.rs | 159 ++++++++++++++++++ 7 files changed, 360 insertions(+), 239 deletions(-) create mode 100644 src/uu/shuf/src/compat_random_source.rs delete mode 100644 src/uu/shuf/src/rand_read_adapter.rs diff --git a/.vscode/cspell.dictionaries/jargon.wordlist.txt b/.vscode/cspell.dictionaries/jargon.wordlist.txt index d0957090f4f..972598361ae 100644 --- a/.vscode/cspell.dictionaries/jargon.wordlist.txt +++ b/.vscode/cspell.dictionaries/jargon.wordlist.txt @@ -93,6 +93,7 @@ mergeable microbenchmark microbenchmarks microbenchmarking +monomorphized multibyte multicall nmerge @@ -107,6 +108,7 @@ nolinks nonblock nonportable nonprinting +nonrepeating nonseekable notrunc nowrite diff --git a/src/uu/shuf/locales/en-US.ftl b/src/uu/shuf/locales/en-US.ftl index 913ddaa1d01..477684fb241 100644 --- a/src/uu/shuf/locales/en-US.ftl +++ b/src/uu/shuf/locales/en-US.ftl @@ -20,6 +20,7 @@ shuf-error-failed-to-open-for-writing = failed to open { $file } for writing shuf-error-failed-to-open-random-source = failed to open random source { $file } shuf-error-read-error = read error shuf-error-read-random-bytes = reading random bytes failed +shuf-error-end-of-random-bytes = end of random source shuf-error-no-lines-to-repeat = no lines to repeat shuf-error-start-exceeds-end = start exceeds end shuf-error-missing-dash = missing '-' diff 
--git a/src/uu/shuf/src/compat_random_source.rs b/src/uu/shuf/src/compat_random_source.rs new file mode 100644 index 00000000000..9d2d1e3b2fb --- /dev/null +++ b/src/uu/shuf/src/compat_random_source.rs @@ -0,0 +1,107 @@ +use std::io::BufRead; + +use uucore::error::{FromIo, UResult, USimpleError}; +use uucore::translate; + +/// A uniform integer generator that tries to exactly match GNU shuf's --random-source. +/// +/// It's not particularly efficient and possibly not quite uniform. It should *only* be +/// used for compatibility with GNU: other modes shouldn't touch this code. +/// +/// All the logic here was black box reverse engineered. It might not match up in all edge +/// cases but it gives identical results on many different large and small inputs. +/// +/// It seems that GNU uses fairly textbook rejection sampling to generate integers, reading +/// one byte at a time until it has enough entropy, and recycling leftover entropy after +/// accepting or rejecting a value. +/// +/// To do your own experiments, start with commands like these: +/// +/// printf '\x01\x02\x03\x04' | shuf -i0-255 -r --random-source=/dev/stdin +/// +/// Then vary the integer range and the input and the input length. It can be useful to +/// see when exactly shuf crashes with an "end of file" error. +/// +/// To spot small inconsistencies it's useful to run: +/// +/// diff -y <(my_shuf ...) 
<(shuf -i0-{MAX} -r --random-source={INPUT}) | head -n 50 +pub struct RandomSourceAdapter { + reader: R, + state: u64, + entropy: u64, +} + +impl RandomSourceAdapter { + pub fn new(reader: R) -> Self { + Self { + reader, + state: 0, + entropy: 0, + } + } +} + +impl RandomSourceAdapter { + pub fn get_value(&mut self, at_most: u64) -> UResult { + while self.entropy < at_most { + let buf = self + .reader + .fill_buf() + .map_err_context(|| translate!("shuf-error-read-random-bytes"))?; + let Some(&byte) = buf.first() else { + return Err(USimpleError::new( + 1, + translate!("shuf-error-end-of-random-bytes"), + )); + }; + self.reader.consume(1); + // Is overflow OK here? Won't it cause bias? (Seems to work out...) + self.state = self.state.wrapping_mul(256).wrapping_add(byte as u64); + self.entropy = self.entropy.wrapping_mul(256).wrapping_add(255); + } + + if at_most == u64::MAX { + // at_most + 1 would overflow but this case is easy. + let val = self.state; + self.entropy = 0; + self.state = 0; + return Ok(val); + } + + let num_possibilities = at_most + 1; + + // If the generated number falls within this margin at the upper end of the + // range then we retry to avoid modulo bias. + let margin = ((self.entropy as u128 + 1) % num_possibilities as u128) as u64; + let safe_zone = self.entropy - margin; + + if self.state <= safe_zone { + let val = self.state % num_possibilities; + // Reuse the rest of the state. + self.state /= num_possibilities; + // We need this subtraction, otherwise we consume new input slightly more + // slowly than GNU. Not sure if it checks out mathematically. + self.entropy -= at_most; + self.entropy /= num_possibilities; + Ok(val) + } else { + self.state %= num_possibilities; + self.entropy %= num_possibilities; + // I sure hope the compiler optimizes this tail call. + self.get_value(at_most) + } + } + + pub fn shuffle<'a, T>(&mut self, vals: &'a mut [T], amount: usize) -> UResult<&'a mut [T]> { + // Fisher-Yates shuffle. 
+ // TODO: GNU does something different if amount <= vals.len() and the input is stdin. + // The order changes completely and depends on --head-count. + // No clue what they might do differently and why. + let amount = amount.min(vals.len()); + for idx in 0..amount { + let other_idx = self.get_value((vals.len() - idx - 1) as u64)? as usize + idx; + vals.swap(idx, other_idx); + } + Ok(&mut vals[..amount]) + } +} diff --git a/src/uu/shuf/src/nonrepeating_iterator.rs b/src/uu/shuf/src/nonrepeating_iterator.rs index dfefd117863..41a301a0bf4 100644 --- a/src/uu/shuf/src/nonrepeating_iterator.rs +++ b/src/uu/shuf/src/nonrepeating_iterator.rs @@ -1,32 +1,30 @@ // spell-checker:ignore nonrepeating +// TODO: this iterator is not compatible with GNU when --random-source is used + use std::{collections::HashSet, ops::RangeInclusive}; -use rand::{Rng, seq::SliceRandom}; +use uucore::error::UResult; use crate::WrappedRng; enum NumberSet { - AlreadyListed(HashSet), - Remaining(Vec), + AlreadyListed(HashSet), + Remaining(Vec), } pub(crate) struct NonrepeatingIterator<'a> { - range: RangeInclusive, + range: RangeInclusive, rng: &'a mut WrappedRng, - remaining_count: usize, + remaining_count: u64, buf: NumberSet, } impl<'a> NonrepeatingIterator<'a> { - pub(crate) fn new( - range: RangeInclusive, - rng: &'a mut WrappedRng, - amount: usize, - ) -> Self { + pub(crate) fn new(range: RangeInclusive, rng: &'a mut WrappedRng, amount: u64) -> Self { let capped_amount = if range.start() > range.end() { 0 - } else if range == (0..=usize::MAX) { + } else if range == (0..=u64::MAX) { amount } else { amount.min(range.end() - range.start() + 1) @@ -39,12 +37,12 @@ impl<'a> NonrepeatingIterator<'a> { } } - fn produce(&mut self) -> usize { + fn produce(&mut self) -> UResult { debug_assert!(self.range.start() <= self.range.end()); match &mut self.buf { NumberSet::AlreadyListed(already_listed) => { let chosen = loop { - let guess = self.rng.random_range(self.range.clone()); + let guess = 
self.rng.choose_from_range(self.range.clone())?; let newly_inserted = already_listed.insert(guess); if newly_inserted { break guess; @@ -54,32 +52,32 @@ impl<'a> NonrepeatingIterator<'a> { // the number of attempts to find a number that hasn't been chosen yet increases. // Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values". let range_size = (self.range.end() - self.range.start()).saturating_add(1); - if number_set_should_list_remaining(already_listed.len(), range_size) { + if number_set_should_list_remaining(already_listed.len() as u64, range_size) { let mut remaining = self .range .clone() .filter(|n| !already_listed.contains(n)) .collect::>(); - assert!(remaining.len() >= self.remaining_count); - remaining.partial_shuffle(&mut self.rng, self.remaining_count); - remaining.truncate(self.remaining_count); + assert!(remaining.len() as u64 >= self.remaining_count); + remaining.truncate(self.remaining_count as usize); + self.rng.shuffle(&mut remaining, usize::MAX)?; self.buf = NumberSet::Remaining(remaining); } - chosen + Ok(chosen) } NumberSet::Remaining(remaining_numbers) => { debug_assert!(!remaining_numbers.is_empty()); // We only enter produce() when there is at least one actual element remaining, so popping must always return an element. - remaining_numbers.pop().unwrap() + Ok(remaining_numbers.pop().unwrap()) } } } } impl Iterator for NonrepeatingIterator<'_> { - type Item = usize; + type Item = UResult; - fn next(&mut self) -> Option { + fn next(&mut self) -> Option> { if self.range.is_empty() || self.remaining_count == 0 { return None; } @@ -89,7 +87,7 @@ impl Iterator for NonrepeatingIterator<'_> { } // This could be a method, but it is much easier to test as a stand-alone function. 
-fn number_set_should_list_remaining(listed_count: usize, range_size: usize) -> bool { +fn number_set_should_list_remaining(listed_count: u64, range_size: u64) -> bool { // Arbitrarily determine the switchover point to be around 25%. This is because: // - HashSet has a large space overhead for the hash table load factor. // - This means that somewhere between 25-40%, the memory required for a "positive" HashSet and a "negative" Vec should be the same. @@ -107,17 +105,17 @@ mod test_number_set_decision { #[test] fn test_stay_positive_large_remaining_first() { - assert_eq!(false, number_set_should_list_remaining(0, usize::MAX)); + assert_eq!(false, number_set_should_list_remaining(0, u64::MAX)); } #[test] fn test_stay_positive_large_remaining_second() { - assert_eq!(false, number_set_should_list_remaining(1, usize::MAX)); + assert_eq!(false, number_set_should_list_remaining(1, u64::MAX)); } #[test] fn test_stay_positive_large_remaining_tenth() { - assert_eq!(false, number_set_should_list_remaining(9, usize::MAX)); + assert_eq!(false, number_set_should_list_remaining(9, u64::MAX)); } #[test] @@ -161,22 +159,19 @@ mod test_number_set_decision { // Ensure that we are overflow-free: #[test] fn test_no_crash_exceed_max_size1() { - assert_eq!(false, number_set_should_list_remaining(12345, usize::MAX)); + assert_eq!(false, number_set_should_list_remaining(12345, u64::MAX)); } #[test] fn test_no_crash_exceed_max_size2() { assert_eq!( true, - number_set_should_list_remaining(usize::MAX - 1, usize::MAX) + number_set_should_list_remaining(u64::MAX - 1, u64::MAX) ); } #[test] fn test_no_crash_exceed_max_size3() { - assert_eq!( - true, - number_set_should_list_remaining(usize::MAX, usize::MAX) - ); + assert_eq!(true, number_set_should_list_remaining(u64::MAX, u64::MAX)); } } diff --git a/src/uu/shuf/src/rand_read_adapter.rs b/src/uu/shuf/src/rand_read_adapter.rs deleted file mode 100644 index 84c7e8bf218..00000000000 --- a/src/uu/shuf/src/rand_read_adapter.rs +++ /dev/null @@ 
-1,135 +0,0 @@ -// This file is part of the uutils coreutils package. -// -// For the full copyright and license information, please view the LICENSE -// file that was distributed with this source code. -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! A wrapper around any Read to treat it as an RNG. - -use std::cell::Cell; -use std::io::{Error, Read}; -use std::rc::Rc; - -use rand_core::{RngCore, impls}; - -/// An RNG that reads random bytes straight from any type supporting -/// [`std::io::Read`], for example files. -/// -/// This will work best with an infinite reader, but that is not required. -/// -/// This can be used with `/dev/urandom` on Unix but it is recommended to use -/// [`OsRng`] instead. -/// -/// # Panics -/// -/// `ReadRng` uses [`std::io::Read::read_exact`], which retries on interrupts. -/// All other errors from the underlying reader, including when it does not -/// have enough data, will be reported via the public error field (which can -/// be cloned in advance, as it uses [`Rc`]). This field must be checked for -/// errors after every operation. -/// -/// [`OsRng`]: rand::rngs::OsRng -pub struct ReadRng { - reader: R, - pub error: ErrorCell, -} - -pub type ErrorCell = Rc>>; - -impl ReadRng { - /// Create a new `ReadRng` from a `Read`. - pub fn new(r: R) -> Self { - Self { - reader: r, - error: Rc::default(), - } - } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - if dest.is_empty() { - return Ok(()); - } - // Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`. 
- self.reader.read_exact(dest) - } -} - -impl RngCore for ReadRng { - fn next_u32(&mut self) -> u32 { - impls::next_u32_via_fill(self) - } - - fn next_u64(&mut self) -> u64 { - impls::next_u64_via_fill(self) - } - - fn fill_bytes(&mut self, dest: &mut [u8]) { - if let Err(err) = self.try_fill_bytes(dest) { - // Failed to deliver random data, so the caller must check the error - // cell before using the result. - self.error.set(Some(err)); - } - } -} - -#[cfg(test)] -mod test { - use std::println; - - use super::ReadRng; - use rand::RngCore; - - #[test] - fn test_reader_rng_u64() { - // transmute from the target to avoid endianness concerns. - #[rustfmt::skip] - let v = [0u8, 0, 0, 0, 0, 0, 0, 1, - 0, 4, 0, 0, 3, 0, 0, 2, - 5, 0, 0, 0, 0, 0, 0, 0]; - let mut rng = ReadRng::new(&v[..]); - - assert_eq!(rng.next_u64(), 1 << 56); - assert_eq!(rng.next_u64(), (2 << 56) + (3 << 32) + (4 << 8)); - assert_eq!(rng.next_u64(), 5); - } - - #[test] - fn test_reader_rng_u32() { - let v = [0u8, 0, 0, 1, 0, 0, 2, 0, 3, 0, 0, 0]; - let mut rng = ReadRng::new(&v[..]); - - assert_eq!(rng.next_u32(), 1 << 24); - assert_eq!(rng.next_u32(), 2 << 16); - assert_eq!(rng.next_u32(), 3); - } - - #[test] - fn test_reader_rng_fill_bytes() { - let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; - let mut w = [0u8; 8]; - - let mut rng = ReadRng::new(&v[..]); - rng.fill_bytes(&mut w); - - assert_eq!(v, w); - } - - #[test] - fn test_reader_rng_insufficient_bytes() { - let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; - let mut w = [0u8; 9]; - - let mut rng = ReadRng::new(&v[..]); - - let result = rng.try_fill_bytes(&mut w); - assert!(result.is_err()); - println!("Error: {}", result.unwrap_err()); - } -} diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs index cc04a3068d6..9a3e3afd468 100644 --- a/src/uu/shuf/src/shuf.rs +++ b/src/uu/shuf/src/shuf.rs @@ -7,12 +7,11 @@ use clap::builder::ValueParser; use clap::{Arg, ArgAction, Command}; -use rand::prelude::SliceRandom; -use rand::seq::IndexedRandom; -use rand::{Rng, 
RngCore}; +use rand::Rng; +use rand::seq::{IndexedRandom, SliceRandom}; use std::ffi::{OsStr, OsString}; use std::fs::File; -use std::io::{BufWriter, Error, Read, Write, stdin, stdout}; +use std::io::{BufReader, BufWriter, Error, Read, Write, stdin, stdout}; use std::ops::RangeInclusive; use std::path::{Path, PathBuf}; use std::str::FromStr; @@ -21,21 +20,21 @@ use uucore::error::{FromIo, UResult, USimpleError, UUsageError}; use uucore::format_usage; use uucore::translate; +mod compat_random_source; mod nonrepeating_iterator; -mod rand_read_adapter; use nonrepeating_iterator::NonrepeatingIterator; enum Mode { Default(PathBuf), Echo(Vec), - InputRange(RangeInclusive), + InputRange(RangeInclusive), } const BUF_SIZE: usize = 64 * 1024; struct Options { - head_count: usize, + head_count: u64, output: Option, random_source: Option, repeat: bool, @@ -87,11 +86,11 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { // Busybox takes the final value which is more typical: later // options override earlier options. 
head_count: matches - .get_many::(options::HEAD_COUNT) + .get_many::(options::HEAD_COUNT) .unwrap_or_default() .copied() .min() - .unwrap_or(usize::MAX), + .unwrap_or(u64::MAX), output: matches.get_one(options::OUTPUT).cloned(), random_source: matches.get_one(options::RANDOM_SOURCE).cloned(), repeat: matches.get_flag(options::REPEAT), @@ -125,7 +124,8 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { let file = File::open(r).map_err_context( || translate!("shuf-error-failed-to-open-random-source", "file" => r.quote()), )?; - WrappedRng::RngFile(rand_read_adapter::ReadRng::new(file)) + let file = BufReader::new(file); + WrappedRng::RngFile(compat_random_source::RandomSourceAdapter::new(file)) } None => WrappedRng::RngDefault(rand::rng()), }; @@ -180,7 +180,7 @@ pub fn uu_app() -> Command { .value_name("COUNT") .action(ArgAction::Append) .help(translate!("shuf-help-head-count")) - .value_parser(usize::from_str), + .value_parser(u64::from_str), ) .arg( Arg::new(options::OUTPUT) @@ -250,12 +250,15 @@ fn split_seps(data: &[u8], sep: u8) -> Vec<&[u8]> { trait Shufable { type Item: Writable; fn is_empty(&self) -> bool; - fn choose(&self, rng: &mut WrappedRng) -> Self::Item; + fn choose(&self, rng: &mut WrappedRng) -> UResult; + // In some modes we shuffle ahead of time and in some as we generate + // so we unfortunately need to double-wrap UResult. + // But it's monomorphized so the optimizer will hopefully Take Care Of It™. fn partial_shuffle<'b>( &'b mut self, rng: &'b mut WrappedRng, - amount: usize, - ) -> impl Iterator; + amount: u64, + ) -> UResult>>; } impl<'a> Shufable for Vec<&'a [u8]> { @@ -265,20 +268,22 @@ impl<'a> Shufable for Vec<&'a [u8]> { (**self).is_empty() } - fn choose(&self, rng: &mut WrappedRng) -> Self::Item { - // Note: "copied()" only copies the reference, not the entire [u8]. - // Returns None if the slice is empty. We checked this before, so - // this is safe. 
- (**self).choose(rng).unwrap() + fn choose(&self, rng: &mut WrappedRng) -> UResult { + rng.choose(self) } fn partial_shuffle<'b>( &'b mut self, rng: &'b mut WrappedRng, - amount: usize, - ) -> impl Iterator { - // Note: "copied()" only copies the reference, not the entire [u8]. - (**self).partial_shuffle(rng, amount).0.iter().copied() + amount: u64, + ) -> UResult>> { + // On 32-bit platforms it's possible that amount > usize::MAX. + // We saturate as usize::MAX since all of our shuffling modes require storing + // elements in memory so more than usize::MAX elements won't fit anyway. + // (With --repeat an output larger than usize::MAX is possible. But --repeat + // uses `choose()`.) + let amount = usize::try_from(amount).unwrap_or(usize::MAX); + Ok(rng.shuffle(self, amount)?.iter().copied().map(Ok)) } } @@ -289,36 +294,37 @@ impl<'a> Shufable for Vec<&'a OsStr> { (**self).is_empty() } - fn choose(&self, rng: &mut WrappedRng) -> Self::Item { - (**self).choose(rng).unwrap() + fn choose(&self, rng: &mut WrappedRng) -> UResult { + rng.choose(self) } fn partial_shuffle<'b>( &'b mut self, rng: &'b mut WrappedRng, - amount: usize, - ) -> impl Iterator { - (**self).partial_shuffle(rng, amount).0.iter().copied() + amount: u64, + ) -> UResult>> { + let amount = usize::try_from(amount).unwrap_or(usize::MAX); + Ok(rng.shuffle(self, amount)?.iter().copied().map(Ok)) } } -impl Shufable for RangeInclusive { - type Item = usize; +impl Shufable for RangeInclusive { + type Item = u64; fn is_empty(&self) -> bool { self.is_empty() } - fn choose(&self, rng: &mut WrappedRng) -> usize { - rng.random_range(self.clone()) + fn choose(&self, rng: &mut WrappedRng) -> UResult { + rng.choose_from_range(self.clone()) } fn partial_shuffle<'b>( &'b mut self, rng: &'b mut WrappedRng, - amount: usize, - ) -> impl Iterator { - NonrepeatingIterator::new(self.clone(), rng, amount) + amount: u64, + ) -> UResult>> { + Ok(NonrepeatingIterator::new(self.clone(), rng, amount)) } } @@ -338,7 +344,7 @@ impl 
Writable for &OsStr { } } -impl Writable for usize { +impl Writable for u64 { fn write_all_to(&self, output: &mut impl OsWrite) -> Result<(), Error> { // The itoa crate is surprisingly much more efficient than a formatted write. // It speeds up `shuf -r -n1000000 -i1-1024` by 1.8×. @@ -354,7 +360,6 @@ fn shuf_exec( output: &mut BufWriter>, ) -> UResult<()> { let ctx = || translate!("shuf-error-write-failed"); - let error_cell = rng.get_error_cell(); if opts.repeat { if input.is_empty() { return Err(USimpleError::new( @@ -363,17 +368,16 @@ fn shuf_exec( )); } for _ in 0..opts.head_count { - let r = input.choose(rng); - WrappedRng::check_error(error_cell.as_ref())?; + let r = input.choose(rng)?; r.write_all_to(output).map_err_context(ctx)?; output.write_all(&[opts.sep]).map_err_context(ctx)?; } } else { - let shuffled = input.partial_shuffle(rng, opts.head_count); - WrappedRng::check_error(error_cell.as_ref())?; + let shuffled = input.partial_shuffle(rng, opts.head_count)?; for r in shuffled { + let r = r?; r.write_all_to(output).map_err_context(ctx)?; output.write_all(&[opts.sep]).map_err_context(ctx)?; } @@ -383,10 +387,10 @@ fn shuf_exec( Ok(()) } -fn parse_range(input_range: &str) -> Result, String> { +fn parse_range(input_range: &str) -> Result, String> { if let Some((from, to)) = input_range.split_once('-') { - let begin = from.parse::().map_err(|e| e.to_string())?; - let end = to.parse::().map_err(|e| e.to_string())?; + let begin = from.parse::().map_err(|e| e.to_string())?; + let end = to.parse::().map_err(|e| e.to_string())?; if begin <= end || begin == end + 1 { Ok(begin..=end) } else { @@ -398,48 +402,36 @@ fn parse_range(input_range: &str) -> Result, String> { } enum WrappedRng { - RngFile(rand_read_adapter::ReadRng), RngDefault(rand::rngs::ThreadRng), + RngFile(compat_random_source::RandomSourceAdapter>), } impl WrappedRng { - fn get_error_cell(&self) -> Option { - if let Self::RngFile(adapter) = self { - Some(adapter.error.clone()) - } else { - None - } 
- } - - fn check_error(error_cell: Option<&rand_read_adapter::ErrorCell>) -> UResult<()> { - if let Some(cell) = error_cell { - if let Some(err) = cell.take() { - return Err(err.map_err_context(|| translate!("shuf-error-read-random-bytes"))); - } - } - Ok(()) - } -} - -impl RngCore for WrappedRng { - fn next_u32(&mut self) -> u32 { + fn choose(&mut self, vals: &[T]) -> UResult { match self { - Self::RngFile(r) => r.next_u32(), - Self::RngDefault(r) => r.next_u32(), + Self::RngDefault(rng) => Ok(*vals.choose(rng).unwrap()), + Self::RngFile(adapter) => { + assert!(!vals.is_empty()); + let idx = adapter.get_value(vals.len() as u64 - 1)? as usize; + Ok(vals[idx]) + } } } - fn next_u64(&mut self) -> u64 { + fn shuffle<'a, T>(&mut self, vals: &'a mut [T], amount: usize) -> UResult<&'a mut [T]> { match self { - Self::RngFile(r) => r.next_u64(), - Self::RngDefault(r) => r.next_u64(), + Self::RngDefault(rng) => Ok(vals.partial_shuffle(rng, amount).0), + Self::RngFile(adapter) => adapter.shuffle(vals, amount), } } - fn fill_bytes(&mut self, dest: &mut [u8]) { + fn choose_from_range(&mut self, range: RangeInclusive) -> UResult { match self { - Self::RngFile(r) => r.fill_bytes(dest), - Self::RngDefault(r) => r.fill_bytes(dest), + Self::RngDefault(rng) => Ok(rng.random_range(range)), + Self::RngFile(adapter) => { + let offset = adapter.get_value(*range.end() - *range.start())?; + Ok(*range.start() + offset) + } } } } diff --git a/tests/by-util/test_shuf.rs b/tests/by-util/test_shuf.rs index 83f02f04918..1b5d2a99c0e 100644 --- a/tests/by-util/test_shuf.rs +++ b/tests/by-util/test_shuf.rs @@ -859,3 +859,162 @@ fn write_errors_are_reported() { .no_stdout() .stderr_is("shuf: write failed: No space left on device\n"); } + +// On 32-bit platforms, if we cast carelessly, this will give no output. 
+#[test] +fn test_head_count_does_not_overflow_file() { + let (at, mut ucmd) = at_and_ucmd!(); + + at.append("input.txt", "hello\n"); + + ucmd.arg(format!("-n{}", u64::from(u32::MAX) + 1)) + .arg("input.txt") + .succeeds() + .stdout_is("hello\n") + .no_stderr(); +} + +#[test] +fn test_head_count_does_not_overflow_args() { + new_ucmd!() + .arg(format!("-n{}", u64::from(u32::MAX) + 1)) + .arg("-e") + .arg("goodbye") + .succeeds() + .stdout_is("goodbye\n") + .no_stderr(); +} + +#[test] +fn test_head_count_does_not_overflow_range() { + new_ucmd!() + .arg(format!("-n{}", u64::from(u32::MAX) + 1)) + .arg("-i1-1") + .succeeds() + .stdout_is("1\n") + .no_stderr(); +} + +// Test reproducibility and compatibility of --random-source. +// These hard-coded results match those of GNU shuf. They should not be changed. + +#[test] +fn test_gnu_compat_range_repeat() { + let (at, mut ucmd) = at_and_ucmd!(); + at.append_bytes( + "random_bytes.bin", + b"\xfb\x83\x8f\x21\x9b\x3c\x2d\xc5\x73\xa5\x58\x6c\x54\x2f\x59\xf8", + ); + + ucmd.arg("--random-source=random_bytes.bin") + .arg("-r") + .arg("-i1-99") + .fails_with_code(1) + .stderr_is("shuf: end of random source\n") + .stdout_is("38\n30\n10\n26\n23\n61\n46\n99\n75\n43\n10\n89\n10\n44\n24\n59\n22\n51\n"); +} + +#[test] +fn test_gnu_compat_args_no_repeat() { + let (at, mut ucmd) = at_and_ucmd!(); + at.append_bytes( + "random_bytes.bin", + b"\xd1\xfd\xb9\x9a\xf5\x81\x71\x42\xf9\x7a\x59\x79\xd4\x9c\x8c\x7d", + ); + + ucmd.arg("--random-source=random_bytes.bin") + .arg("-e") + .args(&["1", "2", "3", "4", "5", "6", "7"][..]) + .succeeds() + .no_stderr() + .stdout_is("7\n1\n2\n5\n3\n4\n6\n"); +} + +#[test] +fn test_gnu_compat_from_stdin() { + let (at, mut ucmd) = at_and_ucmd!(); + at.append_bytes( + "random_bytes.bin", + b"\xd1\xfd\xb9\x9a\xf5\x81\x71\x42\xf9\x7a\x59\x79\xd4\x9c\x8c\x7d", + ); + + at.append("input.txt", "1\n2\n3\n4\n5\n6\n7\n"); + + ucmd.arg("--random-source=random_bytes.bin") + .set_stdin(at.open("input.txt")) + .succeeds() 
+ .no_stderr() + .stdout_is("7\n1\n2\n5\n3\n4\n6\n"); +} + +#[test] +fn test_gnu_compat_from_file() { + let (at, mut ucmd) = at_and_ucmd!(); + at.append_bytes( + "random_bytes.bin", + b"\xd1\xfd\xb9\x9a\xf5\x81\x71\x42\xf9\x7a\x59\x79\xd4\x9c\x8c\x7d", + ); + + at.append("input.txt", "1\n2\n3\n4\n5\n6\n7\n"); + + ucmd.arg("--random-source=random_bytes.bin") + .arg("input.txt") + .succeeds() + .no_stderr() + .stdout_is("7\n1\n2\n5\n3\n4\n6\n"); +} + +#[test] +fn test_gnu_compat_limited_from_file() { + let (at, mut ucmd) = at_and_ucmd!(); + at.append_bytes( + "random_bytes.bin", + b"\xd1\xfd\xb9\x9a\xf5\x81\x71\x42\xf9\x7a\x59\x79\xd4\x9c\x8c\x7d", + ); + + at.append("input.txt", "1\n2\n3\n4\n5\n6\n7\n"); + + ucmd.arg("--random-source=random_bytes.bin") + .arg("-n5") + .arg("input.txt") + .succeeds() + .no_stderr() + .stdout_is("7\n1\n2\n5\n3\n"); +} + +// This specific case causes GNU to give different results than other modes. +#[ignore = "disabled until fixed"] +#[test] +fn test_gnu_compat_limited_from_stdin() { + let (at, mut ucmd) = at_and_ucmd!(); + at.append_bytes( + "random_bytes.bin", + b"\xd1\xfd\xb9\x9a\xf5\x81\x71\x42\xf9\x7a\x59\x79\xd4\x9c\x8c\x7d", + ); + + at.append("input.txt", "1\n2\n3\n4\n5\n6\n7\n"); + + ucmd.arg("--random-source=random_bytes.bin") + .arg("-n7") + .set_stdin(at.open("input.txt")) + .succeeds() + .no_stderr() + .stdout_is("6\n5\n1\n3\n2\n7\n4\n"); +} + +// We haven't reverse-engineered GNU's nonrepeating integer sampling yet. 
+#[ignore = "disabled until fixed"] +#[test] +fn test_gnu_compat_range_no_repeat() { + let (at, mut ucmd) = at_and_ucmd!(); + at.append_bytes( + "random_bytes.bin", + b"\xd1\xfd\xb9\x9a\xf5\x81\x71\x42\xf9\x7a\x59\x79\xd4\x9c\x8c\x7d", + ); + + ucmd.arg("--random-source=random_bytes.bin") + .arg("-i1-10") + .succeeds() + .no_stderr() + .stdout_is("10\n2\n8\n7\n3\n9\n6\n5\n1\n4\n"); +} From 6f97141cf44a7b9a4a2739c1e4698eb98bcffd8a Mon Sep 17 00:00:00 2001 From: Jan Verbeek Date: Wed, 26 Mar 2025 15:51:55 +0100 Subject: [PATCH 7/8] shuf: feature: Add --random-seed option MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds a new option to get reproducible output from a seed. This was already possible with --random-source, but doing that properly was tricky and had poor performance. Adding this option implies a commitment to keep using the exact same algorithms in the future. For that reason we only use third-party libraries for well-known algorithms and implement our own distributions on top of that. ----- As a teenager on King's Day I once used `shuf` for divination. People paid €0.50 to enter a cramped tent and sat down next to me behind an old netbook. I would ask their name and their sun sign and pipe this information into `shuf --random-source=/dev/stdin`, which selected pseudo-random dictionary words and `tee`d them into `espeak`. If someone's name was too short `shuf` crashed with an end of file error. --random-seed would have worked better. 
--- .../cspell.dictionaries/jargon.wordlist.txt | 1 + .../cspell.dictionaries/people.wordlist.txt | 3 + Cargo.lock | 2 + Cargo.toml | 1 + src/uu/shuf/Cargo.toml | 2 + src/uu/shuf/locales/en-US.ftl | 1 + src/uu/shuf/src/compat_random_source.rs | 24 +++- src/uu/shuf/src/random_seed.rs | 118 ++++++++++++++++++ src/uu/shuf/src/shuf.rs | 77 ++++++++---- tests/by-util/test_shuf.rs | 49 ++++++++ 10 files changed, 250 insertions(+), 28 deletions(-) create mode 100644 src/uu/shuf/src/random_seed.rs diff --git a/.vscode/cspell.dictionaries/jargon.wordlist.txt b/.vscode/cspell.dictionaries/jargon.wordlist.txt index 972598361ae..bd0e7ae4149 100644 --- a/.vscode/cspell.dictionaries/jargon.wordlist.txt +++ b/.vscode/cspell.dictionaries/jargon.wordlist.txt @@ -55,6 +55,7 @@ fileio filesystem filesystems flamegraph +footgun freeram fsxattr fullblock diff --git a/.vscode/cspell.dictionaries/people.wordlist.txt b/.vscode/cspell.dictionaries/people.wordlist.txt index 8fe38d88538..446c00df4b6 100644 --- a/.vscode/cspell.dictionaries/people.wordlist.txt +++ b/.vscode/cspell.dictionaries/people.wordlist.txt @@ -37,6 +37,9 @@ Boden Garman Chirag B Jadwani Chirag Jadwani +Daniel Lemire + Daniel + Lemire Derek Chiang Derek Chiang diff --git a/Cargo.lock b/Cargo.lock index d0126ab3a4d..2e973866b40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3824,7 +3824,9 @@ dependencies = [ "fluent", "itoa", "rand 0.9.2", + "rand_chacha 0.9.0", "rand_core 0.9.5", + "sha3", "tempfile", "uucore", ] diff --git a/Cargo.toml b/Cargo.toml index 121b155381f..845dfe67ce5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -358,6 +358,7 @@ phf_codegen = "0.13.1" platform-info = "2.0.3" procfs = "0.18" rand = { version = "0.9.0", features = ["small_rng"] } +rand_chacha = { version = "0.9.0" } rand_core = "0.9.0" rayon = "1.10" regex = "1.10.4" diff --git a/src/uu/shuf/Cargo.toml b/src/uu/shuf/Cargo.toml index ee4e217d0f4..494de156bb8 100644 --- a/src/uu/shuf/Cargo.toml +++ b/src/uu/shuf/Cargo.toml @@ -21,7 +21,9 @@ path = 
"src/shuf.rs" clap = { workspace = true } itoa = { workspace = true } rand = { workspace = true } +rand_chacha = { workspace = true } rand_core = { workspace = true } +sha3 = { workspace = true } uucore = { workspace = true } fluent = { workspace = true } diff --git a/src/uu/shuf/locales/en-US.ftl b/src/uu/shuf/locales/en-US.ftl index 477684fb241..de322117983 100644 --- a/src/uu/shuf/locales/en-US.ftl +++ b/src/uu/shuf/locales/en-US.ftl @@ -10,6 +10,7 @@ shuf-help-echo = treat each ARG as an input line shuf-help-input-range = treat each number LO through HI as an input line shuf-help-head-count = output at most COUNT lines shuf-help-output = write result to FILE instead of standard output +shuf-help-random-seed = seed with STRING for reproducible output shuf-help-random-source = get random bytes from FILE shuf-help-repeat = output lines can be repeated shuf-help-zero-terminated = line delimiter is NUL, not newline diff --git a/src/uu/shuf/src/compat_random_source.rs b/src/uu/shuf/src/compat_random_source.rs index 9d2d1e3b2fb..73a7191be73 100644 --- a/src/uu/shuf/src/compat_random_source.rs +++ b/src/uu/shuf/src/compat_random_source.rs @@ -1,4 +1,9 @@ -use std::io::BufRead; +// This file is part of the uutils coreutils package. +// +// For the full copyright and license information, please view the LICENSE +// file that was distributed with this source code. + +use std::{io::BufRead, ops::RangeInclusive}; use uucore::error::{FromIo, UResult, USimpleError}; use uucore::translate; @@ -42,7 +47,7 @@ impl RandomSourceAdapter { } impl RandomSourceAdapter { - pub fn get_value(&mut self, at_most: u64) -> UResult { + fn generate_at_most(&mut self, at_most: u64) -> UResult { while self.entropy < at_most { let buf = self .reader @@ -88,10 +93,21 @@ impl RandomSourceAdapter { self.state %= num_possibilities; self.entropy %= num_possibilities; // I sure hope the compiler optimizes this tail call. 
- self.get_value(at_most) + self.generate_at_most(at_most) } } + pub fn choose_from_range(&mut self, range: RangeInclusive) -> UResult { + let offset = self.generate_at_most(*range.end() - *range.start())?; + Ok(*range.start() + offset) + } + + pub fn choose_from_slice(&mut self, vals: &[T]) -> UResult { + assert!(!vals.is_empty()); + let idx = self.generate_at_most(vals.len() as u64 - 1)? as usize; + Ok(vals[idx]) + } + pub fn shuffle<'a, T>(&mut self, vals: &'a mut [T], amount: usize) -> UResult<&'a mut [T]> { // Fisher-Yates shuffle. // TODO: GNU does something different if amount <= vals.len() and the input is stdin. @@ -99,7 +115,7 @@ impl RandomSourceAdapter { // No clue what they might do differently and why. let amount = amount.min(vals.len()); for idx in 0..amount { - let other_idx = self.get_value((vals.len() - idx - 1) as u64)? as usize + idx; + let other_idx = self.generate_at_most((vals.len() - idx - 1) as u64)? as usize + idx; vals.swap(idx, other_idx); } Ok(&mut vals[..amount]) diff --git a/src/uu/shuf/src/random_seed.rs b/src/uu/shuf/src/random_seed.rs new file mode 100644 index 00000000000..f66ad62f834 --- /dev/null +++ b/src/uu/shuf/src/random_seed.rs @@ -0,0 +1,118 @@ +// This file is part of the uutils coreutils package. +// +// For the full copyright and license information, please view the LICENSE +// file that was distributed with this source code. + +use std::ops::RangeInclusive; + +use rand::{RngCore as _, SeedableRng as _}; +use rand_chacha::ChaCha12Rng; +use sha3::{Digest as _, Sha3_256}; + +/// Reproducible seeded random number generation. +/// +/// The behavior should stay the same between releases, so don't change it without +/// a very good reason. +/// +/// # How it works +/// +/// - Take a Unicode string as the seed. +/// +/// - Encode this seed as UTF-8. +/// +/// - Take the SHA3-256 hash of the encoded seed. +/// +/// - Use that hash as the input for a [`rand_chacha`] ChaCha12 RNG. 
+/// (We don't touch the nonce, so that's probably zero.) +/// +/// - Take 64-bit samples from the RNG. +/// +/// - Use Lemire's method to generate uniformly distributed integers and: +/// +/// - With --repeat, use these to pick elements from ranges. +/// +/// - Without --repeat, use these to do left-to-right modern Fisher-Yates. +/// +/// - Or for --input-range without --repeat, do whatever NonrepeatingIterator does. +/// (We may want to change that. Watch this space.) +/// +/// # Why it works like this +/// +/// - Unicode string: Greatest common denominator between platforms. Windows doesn't +/// let you pass raw bytes as a CLI argument and that would be bad practice anyway. +/// A decimal or hex number would work but this is much more flexible without being +/// unmanageable. +/// +/// (Footgun: if the user passes a filename we won't read from the file but the +/// command will run anyway.) +/// +/// - UTF-8: That's what Rust likes and it's the least unreasonable Unicode encoding. +/// +/// - SHA3-256: We want to make good use of the entire user input and SHA-3 is +/// state of the art. ChaCha12 takes a 256-bit seed. +/// +/// - ChaCha12: [`rand`]'s default rng as of writing. Seems state of the art. +/// +/// - 64-bit samples: We could often get away with 32-bit samples but let's keep things +/// simple and only use one width. (There doesn't seem to be much of a performance hit.) +/// +/// - Lemire, Fisher-Yates: These are very easy to implement and maintain ourselves. +/// `rand` provides fancier implementations but only promises reproducibility within +/// patch releases: +/// +/// Strictly speaking even `ChaCha12` is subject to breakage. But since it's a very +/// specific algorithm I assume it's safe in practice. 
+pub struct SeededRng(Box); + +impl SeededRng { + pub fn new(seed: &str) -> Self { + let mut hasher = Sha3_256::new(); + hasher.update(seed.as_bytes()); + let seed = hasher.finalize(); + let seed = seed.as_slice().try_into().unwrap(); + Self(Box::new(rand_chacha::ChaCha12Rng::from_seed(seed))) + } + + #[allow(clippy::many_single_char_names)] // use original lemire names for easy comparison + fn generate_at_most(&mut self, at_most: u64) -> u64 { + if at_most == u64::MAX { + return self.0.next_u64(); + } + + // https://lemire.me/blog/2019/06/06/nearly-divisionless-random-integer-generation-on-various-systems/ + let s: u64 = at_most + 1; + let mut x: u64 = self.0.next_u64(); + let mut m: u128 = u128::from(x) * u128::from(s); + let mut l: u64 = m as u64; + if l < s { + let t: u64 = s.wrapping_neg() % s; + while l < t { + x = self.0.next_u64(); + m = u128::from(x) * u128::from(s); + l = m as u64; + } + } + (m >> 64) as u64 + } + + pub fn choose_from_range(&mut self, range: RangeInclusive) -> u64 { + let offset = self.generate_at_most(*range.end() - *range.start()); + *range.start() + offset + } + + pub fn choose_from_slice(&mut self, vals: &[T]) -> T { + assert!(!vals.is_empty()); + let idx = self.generate_at_most(vals.len() as u64 - 1) as usize; + vals[idx] + } + + pub fn shuffle<'a, T>(&mut self, vals: &'a mut [T], amount: usize) -> &'a mut [T] { + // Fisher-Yates shuffle. 
+ let amount = amount.min(vals.len()); + for idx in 0..amount { + let other_idx = self.generate_at_most((vals.len() - idx - 1) as u64) as usize + idx; + vals.swap(idx, other_idx); + } + &mut vals[..amount] + } +} diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs index 9a3e3afd468..e2cb2e95850 100644 --- a/src/uu/shuf/src/shuf.rs +++ b/src/uu/shuf/src/shuf.rs @@ -5,16 +5,20 @@ // spell-checker:ignore (ToDO) cmdline evec nonrepeating seps shufable rvec fdata -use clap::builder::ValueParser; -use clap::{Arg, ArgAction, Command}; -use rand::Rng; -use rand::seq::{IndexedRandom, SliceRandom}; use std::ffi::{OsStr, OsString}; use std::fs::File; use std::io::{BufReader, BufWriter, Error, Read, Write, stdin, stdout}; use std::ops::RangeInclusive; use std::path::{Path, PathBuf}; use std::str::FromStr; + +use clap::{Arg, ArgAction, Command, builder::ValueParser}; +use rand::rngs::ThreadRng; +use rand::{ + Rng, + seq::{IndexedRandom, SliceRandom}, +}; + use uucore::display::{OsWrite, Quotable}; use uucore::error::{FromIo, UResult, USimpleError, UUsageError}; use uucore::format_usage; @@ -22,8 +26,11 @@ use uucore::translate; mod compat_random_source; mod nonrepeating_iterator; +mod random_seed; +use compat_random_source::RandomSourceAdapter; use nonrepeating_iterator::NonrepeatingIterator; +use random_seed::SeededRng; enum Mode { Default(PathBuf), @@ -36,17 +43,24 @@ const BUF_SIZE: usize = 64 * 1024; struct Options { head_count: u64, output: Option, - random_source: Option, + random_source: RandomSource, repeat: bool, sep: u8, } +enum RandomSource { + None, + Seed(String), + File(PathBuf), +} + mod options { pub static ECHO: &str = "echo"; pub static INPUT_RANGE: &str = "input-range"; pub static HEAD_COUNT: &str = "head-count"; pub static OUTPUT: &str = "output"; pub static RANDOM_SOURCE: &str = "random-source"; + pub static RANDOM_SEED: &str = "random-seed"; pub static REPEAT: &str = "repeat"; pub static ZERO_TERMINATED: &str = "zero-terminated"; pub static 
FILE_OR_ARGS: &str = "file-or-args"; @@ -80,6 +94,14 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { Mode::Default(file.into()) }; + let random_source = if let Some(filename) = matches.get_one(options::RANDOM_SOURCE).cloned() { + RandomSource::File(filename) + } else if let Some(seed) = matches.get_one(options::RANDOM_SEED).cloned() { + RandomSource::Seed(seed) + } else { + RandomSource::None + }; + let options = Options { // GNU shuf takes the lowest value passed, so we imitate that. // It's probably a bug or an implementation artifact though. @@ -92,7 +114,7 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { .min() .unwrap_or(u64::MAX), output: matches.get_one(options::OUTPUT).cloned(), - random_source: matches.get_one(options::RANDOM_SOURCE).cloned(), + random_source, repeat: matches.get_flag(options::REPEAT), sep: if matches.get_flag(options::ZERO_TERMINATED) { b'\0' @@ -120,14 +142,15 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { } let mut rng = match options.random_source { - Some(ref r) => { + RandomSource::None => WrappedRng::Default(rand::rng()), + RandomSource::Seed(ref seed) => WrappedRng::Seed(SeededRng::new(seed)), + RandomSource::File(ref r) => { let file = File::open(r).map_err_context( || translate!("shuf-error-failed-to-open-random-source", "file" => r.quote()), )?; let file = BufReader::new(file); - WrappedRng::RngFile(compat_random_source::RandomSourceAdapter::new(file)) + WrappedRng::File(compat_random_source::RandomSourceAdapter::new(file)) } - None => WrappedRng::RngDefault(rand::rng()), }; match mode { @@ -191,6 +214,15 @@ pub fn uu_app() -> Command { .value_parser(ValueParser::path_buf()) .value_hint(clap::ValueHint::FilePath), ) + .arg( + Arg::new(options::RANDOM_SEED) + .long(options::RANDOM_SEED) + .value_name("STRING") + .help(translate!("shuf-help-random-seed")) + .value_parser(ValueParser::string()) + .value_hint(clap::ValueHint::Other) + .conflicts_with(options::RANDOM_SOURCE), + ) .arg( 
Arg::new(options::RANDOM_SOURCE) .long(options::RANDOM_SOURCE) @@ -402,36 +434,33 @@ fn parse_range(input_range: &str) -> Result, String> { } enum WrappedRng { - RngDefault(rand::rngs::ThreadRng), - RngFile(compat_random_source::RandomSourceAdapter>), + Default(ThreadRng), + Seed(SeededRng), + File(RandomSourceAdapter>), } impl WrappedRng { fn choose(&mut self, vals: &[T]) -> UResult { match self { - Self::RngDefault(rng) => Ok(*vals.choose(rng).unwrap()), - Self::RngFile(adapter) => { - assert!(!vals.is_empty()); - let idx = adapter.get_value(vals.len() as u64 - 1)? as usize; - Ok(vals[idx]) - } + Self::Default(rng) => Ok(*vals.choose(rng).unwrap()), + Self::Seed(rng) => Ok(rng.choose_from_slice(vals)), + Self::File(rng) => rng.choose_from_slice(vals), } } fn shuffle<'a, T>(&mut self, vals: &'a mut [T], amount: usize) -> UResult<&'a mut [T]> { match self { - Self::RngDefault(rng) => Ok(vals.partial_shuffle(rng, amount).0), - Self::RngFile(adapter) => adapter.shuffle(vals, amount), + Self::Default(rng) => Ok(vals.partial_shuffle(rng, amount).0), + Self::Seed(rng) => Ok(rng.shuffle(vals, amount)), + Self::File(rng) => rng.shuffle(vals, amount), } } fn choose_from_range(&mut self, range: RangeInclusive) -> UResult { match self { - Self::RngDefault(rng) => Ok(rng.random_range(range)), - Self::RngFile(adapter) => { - let offset = adapter.get_value(*range.end() - *range.start())?; - Ok(*range.start() + offset) - } + Self::Default(rng) => Ok(rng.random_range(range)), + Self::Seed(rng) => Ok(rng.choose_from_range(range)), + Self::File(rng) => rng.choose_from_range(range), } } } diff --git a/tests/by-util/test_shuf.rs b/tests/by-util/test_shuf.rs index 1b5d2a99c0e..b419419ecde 100644 --- a/tests/by-util/test_shuf.rs +++ b/tests/by-util/test_shuf.rs @@ -1018,3 +1018,52 @@ fn test_gnu_compat_range_no_repeat() { .no_stderr() .stdout_is("10\n2\n8\n7\n3\n9\n6\n5\n1\n4\n"); } + +// Test reproducibility of --random-seed. 
+// These results are arbitrary but they should not change unless we choose to break compatibility. + +#[test] +fn test_seed_args_repeat() { + new_ucmd!() + .arg("--random-seed=🌱") + .arg("-e") + .arg("-r") + .arg("-n10") + .args(&["foo", "bar", "baz", "qux"]) + .succeeds() + .no_stderr() + .stdout_is("qux\nbar\nbaz\nfoo\nbaz\nqux\nqux\nfoo\nqux\nqux\n"); +} + +#[test] +fn test_seed_args_no_repeat() { + new_ucmd!() + .arg("--random-seed=🌱") + .arg("-e") + .args(&["foo", "bar", "baz", "qux"]) + .succeeds() + .no_stderr() + .stdout_is("qux\nbaz\nfoo\nbar\n"); +} + +#[test] +fn test_seed_range_repeat() { + new_ucmd!() + .arg("--random-seed=🦀") + .arg("-r") + .arg("-i1-99") + .arg("-n10") + .succeeds() + .no_stderr() + .stdout_is("60\n44\n38\n41\n63\n43\n31\n71\n46\n90\n"); +} + +#[test] +fn test_seed_range_no_repeat() { + new_ucmd!() + .arg("--random-seed=12345") + .arg("-i1-10") + .succeeds() + .no_stderr() + .stdout_is("8\n9\n5\n10\n1\n2\n4\n7\n3\n6\n"); +} From a1f4d08600bef472b5c12266368edc319de192de Mon Sep 17 00:00:00 2001 From: Jan Verbeek Date: Wed, 26 Mar 2025 17:49:49 +0100 Subject: [PATCH 8/8] shuf: correctness: Use Fisher-Yates for nonrepeating integers We used to use a clever homegrown way to sample integers. But GNU shuf with --random-source observably uses Fisher-Yates, and the output of the old version depended on a heuristic (making it dangerous for --random-seed). So now we do Fisher-Yates here, just like we do for other inputs. In deterministic modes the output for --input-range is identical to that for piping `seq` into `shuf`. We imitate the old algorithm's method for keeping the resource use in check. The performance of the new version is very close to that of the old version: I haven't found any cases where it's much faster or much slower. 
--- src/uu/shuf/src/nonrepeating_iterator.rs | 230 ++++++++--------------- src/uu/shuf/src/random_seed.rs | 3 - src/uu/shuf/src/shuf.rs | 3 +- tests/by-util/test_shuf.rs | 51 ++++- 4 files changed, 132 insertions(+), 155 deletions(-) diff --git a/src/uu/shuf/src/nonrepeating_iterator.rs b/src/uu/shuf/src/nonrepeating_iterator.rs index 41a301a0bf4..d05844ba9d1 100644 --- a/src/uu/shuf/src/nonrepeating_iterator.rs +++ b/src/uu/shuf/src/nonrepeating_iterator.rs @@ -1,74 +1,85 @@ -// spell-checker:ignore nonrepeating - -// TODO: this iterator is not compatible with GNU when --random-source is used - -use std::{collections::HashSet, ops::RangeInclusive}; +use std::collections::HashMap; +use std::ops::RangeInclusive; use uucore::error::UResult; use crate::WrappedRng; -enum NumberSet { - AlreadyListed(HashSet), - Remaining(Vec), -} - +/// An iterator that samples from an integer range without repetition. +/// +/// This is based on Fisher-Yates, and it's required for backward compatibility +/// that it behaves exactly like Fisher-Yates if --random-source or --random-seed +/// is used. But we have a few tricks: +/// +/// - In the beginning we use a hash table instead of an array. This way we lazily +/// keep track of swaps without allocating the entire range upfront. +/// +/// - When the hash table starts to get big relative to the remaining items +/// we switch over to an array. +/// +/// - We store the array backwards so that we can shrink it as we go and free excess +/// memory every now and then. +/// +/// Both the hash table and the array give the same output. +/// +/// There's room for optimization: +/// +/// - Switching over from the hash table to the array is costly. If we happen to know +/// (through --head-count) that only few draws remain then it would be better not +/// to switch. +/// +/// - If the entire range gets used then we might as well allocate an array to start +/// with. But if the user e.g. 
pipes through `head` rather than using --head-count +/// we can't know whether that's the case, so there's a tradeoff. +/// +/// GNU decides the other way: --head-count is noticeably faster than | head. pub(crate) struct NonrepeatingIterator<'a> { - range: RangeInclusive, rng: &'a mut WrappedRng, - remaining_count: u64, - buf: NumberSet, + values: Values, +} + +enum Values { + Full(Vec), + Sparse(RangeInclusive, HashMap), } impl<'a> NonrepeatingIterator<'a> { - pub(crate) fn new(range: RangeInclusive, rng: &'a mut WrappedRng, amount: u64) -> Self { - let capped_amount = if range.start() > range.end() { - 0 - } else if range == (0..=u64::MAX) { - amount - } else { - amount.min(range.end() - range.start() + 1) - }; - NonrepeatingIterator { - range, - rng, - remaining_count: capped_amount, - buf: NumberSet::AlreadyListed(HashSet::default()), - } + pub(crate) fn new(range: RangeInclusive, rng: &'a mut WrappedRng) -> Self { + let values = Values::Sparse(range, HashMap::default()); + NonrepeatingIterator { rng, values } } fn produce(&mut self) -> UResult { - debug_assert!(self.range.start() <= self.range.end()); - match &mut self.buf { - NumberSet::AlreadyListed(already_listed) => { - let chosen = loop { - let guess = self.rng.choose_from_range(self.range.clone())?; - let newly_inserted = already_listed.insert(guess); - if newly_inserted { - break guess; - } - }; - // Once a significant fraction of the interval has already been enumerated, - // the number of attempts to find a number that hasn't been chosen yet increases. - // Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values". 
- let range_size = (self.range.end() - self.range.start()).saturating_add(1); - if number_set_should_list_remaining(already_listed.len() as u64, range_size) { - let mut remaining = self - .range - .clone() - .filter(|n| !already_listed.contains(n)) - .collect::>(); - assert!(remaining.len() as u64 >= self.remaining_count); - remaining.truncate(self.remaining_count as usize); - self.rng.shuffle(&mut remaining, usize::MAX)?; - self.buf = NumberSet::Remaining(remaining); + match &mut self.values { + Values::Full(items) => { + let this_idx = items.len() - 1; + + let other_idx = self.rng.choose_from_range(0..=items.len() as u64 - 1)? as usize; + // Flip the index to pretend we're going left-to-right + let other_idx = items.len() - other_idx - 1; + + items.swap(this_idx, other_idx); + + let val = items.pop().unwrap(); + if items.len().is_power_of_two() && items.len() >= 512 { + items.shrink_to_fit(); } - Ok(chosen) + Ok(val) } - NumberSet::Remaining(remaining_numbers) => { - debug_assert!(!remaining_numbers.is_empty()); - // We only enter produce() when there is at least one actual element remaining, so popping must always return an element. 
- Ok(remaining_numbers.pop().unwrap()) + Values::Sparse(range, items) => { + let this_idx = *range.start(); + let this_val = items.remove(&this_idx).unwrap_or(this_idx); + + let other_idx = self.rng.choose_from_range(range.clone())?; + + let val = if this_idx == other_idx { + this_val + } else { + items.insert(other_idx, this_val).unwrap_or(other_idx) + }; + *range = *range.start() + 1..=*range.end(); + + Ok(val) } } } @@ -77,101 +88,24 @@ impl<'a> NonrepeatingIterator<'a> { impl Iterator for NonrepeatingIterator<'_> { type Item = UResult; - fn next(&mut self) -> Option> { - if self.range.is_empty() || self.remaining_count == 0 { - return None; + fn next(&mut self) -> Option { + match &self.values { + Values::Full(items) if items.is_empty() => return None, + Values::Full(_) => (), + Values::Sparse(range, _) if range.is_empty() => return None, + Values::Sparse(range, items) => { + let range_len = range.size_hint().0 as u64; + if items.len() as u64 >= range_len / 8 { + self.values = Values::Full(hashmap_to_vec(range.clone(), items)); + } + } } - self.remaining_count -= 1; + Some(self.produce()) } } -// This could be a method, but it is much easier to test as a stand-alone function. -fn number_set_should_list_remaining(listed_count: u64, range_size: u64) -> bool { - // Arbitrarily determine the switchover point to be around 25%. This is because: - // - HashSet has a large space overhead for the hash table load factor. - // - This means that somewhere between 25-40%, the memory required for a "positive" HashSet and a "negative" Vec should be the same. - // - HashSet has a small but non-negligible overhead for each lookup, so we have a slight preference for Vec anyway. - // - At 25%, on average 1.33 attempts are needed to find a number that hasn't been taken yet. 
- // - Finally, "24%" is computationally the simplest: - listed_count >= range_size / 4 -} - -#[cfg(test)] -// Since the computed value is a bool, it is more readable to write the expected value out: -#[allow(clippy::bool_assert_comparison)] -mod test_number_set_decision { - use super::number_set_should_list_remaining; - - #[test] - fn test_stay_positive_large_remaining_first() { - assert_eq!(false, number_set_should_list_remaining(0, u64::MAX)); - } - - #[test] - fn test_stay_positive_large_remaining_second() { - assert_eq!(false, number_set_should_list_remaining(1, u64::MAX)); - } - - #[test] - fn test_stay_positive_large_remaining_tenth() { - assert_eq!(false, number_set_should_list_remaining(9, u64::MAX)); - } - - #[test] - fn test_stay_positive_smallish_range_first() { - assert_eq!(false, number_set_should_list_remaining(0, 12345)); - } - - #[test] - fn test_stay_positive_smallish_range_second() { - assert_eq!(false, number_set_should_list_remaining(1, 12345)); - } - - #[test] - fn test_stay_positive_smallish_range_tenth() { - assert_eq!(false, number_set_should_list_remaining(9, 12345)); - } - - #[test] - fn test_stay_positive_small_range_not_too_early() { - assert_eq!(false, number_set_should_list_remaining(1, 10)); - } - - // Don't want to test close to the border, in case we decide to change the threshold. 
- // However, at 50% coverage, we absolutely should switch: - #[test] - fn test_switch_half() { - assert_eq!(true, number_set_should_list_remaining(1234, 2468)); - } - - // Ensure that the decision is monotonous: - #[test] - fn test_switch_late1() { - assert_eq!(true, number_set_should_list_remaining(12340, 12345)); - } - - #[test] - fn test_switch_late2() { - assert_eq!(true, number_set_should_list_remaining(12344, 12345)); - } - - // Ensure that we are overflow-free: - #[test] - fn test_no_crash_exceed_max_size1() { - assert_eq!(false, number_set_should_list_remaining(12345, u64::MAX)); - } - - #[test] - fn test_no_crash_exceed_max_size2() { - assert_eq!( - true, - number_set_should_list_remaining(u64::MAX - 1, u64::MAX) - ); - } - - #[test] - fn test_no_crash_exceed_max_size3() { - assert_eq!(true, number_set_should_list_remaining(u64::MAX, u64::MAX)); - } +fn hashmap_to_vec(range: RangeInclusive, map: &HashMap) -> Vec { + let lookup = |idx| *map.get(&idx).unwrap_or(&idx); + range.rev().map(lookup).collect() } diff --git a/src/uu/shuf/src/random_seed.rs b/src/uu/shuf/src/random_seed.rs index f66ad62f834..dbc6c728c19 100644 --- a/src/uu/shuf/src/random_seed.rs +++ b/src/uu/shuf/src/random_seed.rs @@ -33,9 +33,6 @@ use sha3::{Digest as _, Sha3_256}; /// /// - Without --repeat, use these to do left-to-right modern Fisher-Yates. /// -/// - Or for --input-range without --repeat, do whatever NonrepeatingIterator does. -/// (We may want to change that. Watch this space.) -/// /// # Why it works like this /// /// - Unicode string: Greatest common denominator between platforms. 
Windows doesn't diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs index e2cb2e95850..73290a0fc27 100644 --- a/src/uu/shuf/src/shuf.rs +++ b/src/uu/shuf/src/shuf.rs @@ -356,7 +356,8 @@ impl Shufable for RangeInclusive { rng: &'b mut WrappedRng, amount: u64, ) -> UResult>> { - Ok(NonrepeatingIterator::new(self.clone(), rng, amount)) + let amount = usize::try_from(amount).unwrap_or(usize::MAX); + Ok(NonrepeatingIterator::new(self.clone(), rng).take(amount)) } } diff --git a/tests/by-util/test_shuf.rs b/tests/by-util/test_shuf.rs index b419419ecde..948b3ed0756 100644 --- a/tests/by-util/test_shuf.rs +++ b/tests/by-util/test_shuf.rs @@ -4,6 +4,8 @@ // file that was distributed with this source code. // spell-checker:ignore (ToDO) unwritable +use std::fmt::Write; + use uutests::at_and_ucmd; use uutests::new_ucmd; @@ -1002,8 +1004,6 @@ fn test_gnu_compat_limited_from_stdin() { .stdout_is("6\n5\n1\n3\n2\n7\n4\n"); } -// We haven't reverse-engineered GNU's nonrepeating integer sampling yet. -#[ignore = "disabled until fixed"] #[test] fn test_gnu_compat_range_no_repeat() { let (at, mut ucmd) = at_and_ucmd!(); @@ -1060,10 +1060,55 @@ fn test_seed_range_repeat() { #[test] fn test_seed_range_no_repeat() { + let expected = "8\n9\n1\n5\n2\n6\n4\n3\n10\n7\n"; + new_ucmd!() .arg("--random-seed=12345") .arg("-i1-10") .succeeds() .no_stderr() - .stdout_is("8\n9\n5\n10\n1\n2\n4\n7\n3\n6\n"); + .stdout_is(expected); + + // Piping from e.g. seq gives identical results. + new_ucmd!() + .arg("--random-seed=12345") + .pipe_in("1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n") + .succeeds() + .no_stderr() + .stdout_is(expected); +} + +// Test a longer input to exercise some more code paths in the sparse representation. 
+#[test] +fn test_seed_long_range_no_repeat() { + let expected = "\ + 1\n3\n35\n37\n36\n45\n72\n17\n18\n40\n67\n74\n81\n77\n14\n90\n\ + 7\n12\n80\n54\n23\n61\n29\n41\n15\n56\n6\n32\n82\n76\n11\n2\n100\n\ + 50\n60\n97\n73\n79\n91\n89\n85\n86\n66\n70\n22\n55\n8\n83\n39\n27\n"; + + new_ucmd!() + .arg("--random-seed=67890") + .arg("-i1-100") + .arg("-n50") + .succeeds() + .no_stderr() + .stdout_is(expected); + + let mut test_input = String::new(); + for n in 1..=100 { + writeln!(&mut test_input, "{n}").unwrap(); + } + + new_ucmd!() + .arg("--random-seed=67890") + .pipe_in(test_input.as_bytes()) + .arg("-n50") + .succeeds() + .no_stderr() + .stdout_is(expected); +} + +#[test] +fn test_empty_range_no_repeat() { + new_ucmd!().arg("-i4-3").succeeds().no_stderr().no_stdout(); }