Skip to content
Merged
39 changes: 25 additions & 14 deletions crates/libtest2-harness/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
pub(crate) use crate::*;

#[derive(Debug)]
pub struct TestContext {
mode: RunMode,
run_ignored: bool,
pub(crate) start: std::time::Instant,
pub(crate) mode: RunMode,
pub(crate) run_ignored: bool,
pub(crate) notifier: notify::ArcNotifier,
pub(crate) test_name: String,
}

impl TestContext {
Expand All @@ -26,21 +28,30 @@ impl TestContext {
pub fn current_mode(&self) -> RunMode {
self.mode
}
}

impl TestContext {
pub(crate) fn new() -> Self {
Self {
mode: Default::default(),
run_ignored: false,
}
pub fn notify(&self, event: notify::Event) -> std::io::Result<()> {
self.notifier().notify(event)
}

pub fn elapased_s(&self) -> notify::Elapsed {
notify::Elapsed(self.start.elapsed())
}

pub fn test_name(&self) -> &str {
&self.test_name
}

pub(crate) fn set_mode(&mut self, mode: RunMode) {
self.mode = mode;
pub(crate) fn notifier(&self) -> &notify::ArcNotifier {
&self.notifier
}

pub(crate) fn set_run_ignored(&mut self, yes: bool) {
self.run_ignored = yes;
pub(crate) fn clone(&self) -> Self {
Self {
start: self.start,
mode: self.mode,
run_ignored: self.run_ignored,
notifier: self.notifier.clone(),
test_name: self.test_name.clone(),
}
}
}
160 changes: 51 additions & 109 deletions crates/libtest2-harness/src/harness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ impl Harness<StateArgs> {
pub struct StateParsed {
start: std::time::Instant,
opts: libtest_lexarg::TestOpts,
notifier: Box<dyn notify::Notifier>,
notifier: notify::ArcNotifier,
}
impl HarnessState for StateParsed {}
impl sealed::_HarnessState_is_Sealed for StateParsed {}

impl Harness<StateParsed> {
pub fn discover(
mut self,
self,
cases: impl IntoIterator<Item = impl Case + 'static>,
) -> std::io::Result<Harness<StateDiscovered>> {
self.state.notifier.notify(
Expand Down Expand Up @@ -144,22 +144,22 @@ impl Harness<StateParsed> {
pub struct StateDiscovered {
start: std::time::Instant,
opts: libtest_lexarg::TestOpts,
notifier: Box<dyn notify::Notifier>,
notifier: notify::ArcNotifier,
cases: Vec<Box<dyn Case>>,
}
impl HarnessState for StateDiscovered {}
impl sealed::_HarnessState_is_Sealed for StateDiscovered {}

impl Harness<StateDiscovered> {
pub fn run(mut self) -> std::io::Result<bool> {
pub fn run(self) -> std::io::Result<bool> {
if self.state.opts.list {
Ok(true)
} else {
run(
&self.state.start,
&self.state.opts,
self.state.cases,
self.state.notifier.as_mut(),
self.state.notifier,
)
}
}
Expand Down Expand Up @@ -252,16 +252,16 @@ fn parse<'p>(parser: &mut cli::Parser<'p>) -> Result<libtest_lexarg::TestOpts, c
Ok(opts)
}

fn notifier(opts: &libtest_lexarg::TestOpts) -> Box<dyn notify::Notifier> {
fn notifier(opts: &libtest_lexarg::TestOpts) -> notify::ArcNotifier {
#[cfg(feature = "color")]
let stdout = anstream::stdout();
#[cfg(not(feature = "color"))]
let stdout = std::io::stdout();
match opts.format {
OutputFormat::Json => Box::new(notify::JsonNotifier::new(stdout)),
_ if opts.list => Box::new(notify::TerseListNotifier::new(stdout)),
OutputFormat::Pretty => Box::new(notify::PrettyRunNotifier::new(stdout)),
OutputFormat::Terse => Box::new(notify::TerseRunNotifier::new(stdout)),
OutputFormat::Json => notify::ArcNotifier::new(notify::JsonNotifier::new(stdout)),
_ if opts.list => notify::ArcNotifier::new(notify::TerseListNotifier::new(stdout)),
OutputFormat::Pretty => notify::ArcNotifier::new(notify::PrettyRunNotifier::new(stdout)),
OutputFormat::Terse => notify::ArcNotifier::new(notify::TerseRunNotifier::new(stdout)),
}
}

Expand Down Expand Up @@ -292,7 +292,7 @@ fn run(
start: &std::time::Instant,
opts: &libtest_lexarg::TestOpts,
cases: Vec<Box<dyn Case>>,
notifier: &mut dyn notify::Notifier,
notifier: notify::ArcNotifier,
) -> std::io::Result<bool> {
notifier.notify(
notify::event::RunStart {
Expand All @@ -316,7 +316,6 @@ fn run(

let threads = opts.test_threads.map(|t| t.get()).unwrap_or(1);

let mut context = TestContext::new();
let run_ignored = match opts.run_ignored {
libtest_lexarg::RunIgnored::Yes | libtest_lexarg::RunIgnored::Only => true,
libtest_lexarg::RunIgnored::No => false,
Expand All @@ -331,9 +330,13 @@ fn run(
(false, true) => RunMode::Bench,
(false, false) => unreachable!("libtest-lexarg` should always ensure at least one is set"),
};
context.set_mode(mode);
context.set_run_ignored(run_ignored);
let context = std::sync::Arc::new(context);
let context = TestContext {
start: *start,
mode,
run_ignored,
notifier,
test_name: String::new(),
};

let mut success = true;

Expand All @@ -345,88 +348,49 @@ fn run(
.partition::<Vec<_>, _>(|c| c.exclusive(&context))
};
if !concurrent_cases.is_empty() {
notifier.threaded(true);
struct RunningTest {
join_handle: std::thread::JoinHandle<()>,
}

impl RunningTest {
fn join(
self,
start: &std::time::Instant,
event: &notify::event::CaseComplete,
notifier: &mut dyn notify::Notifier,
) -> std::io::Result<()> {
if self.join_handle.join().is_err() {
let kind = notify::MessageKind::Error;
let message = Some("panicked after reporting success".to_owned());
notifier.notify(
notify::event::CaseMessage {
name: event.name.clone(),
kind,
message,
elapsed_s: Some(notify::Elapsed(start.elapsed())),
}
.into(),
)?;
}
Ok(())
}
}
context.notifier().threaded(true);

// Use a deterministic hasher
type TestMap = std::collections::HashMap<
String,
RunningTest,
std::thread::JoinHandle<std::io::Result<bool>>,
std::hash::BuildHasherDefault<std::collections::hash_map::DefaultHasher>,
>;

let sync_success = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(success));
let mut running_tests: TestMap = Default::default();
let mut pending = 0;
let (tx, rx) = std::sync::mpsc::channel::<notify::Event>();
let mut running: TestMap = Default::default();
let (tx, rx) = std::sync::mpsc::channel::<String>();
let mut remaining = std::collections::VecDeque::from(concurrent_cases);
while pending > 0 || !remaining.is_empty() {
while pending < threads && !remaining.is_empty() {
while !running.is_empty() || !remaining.is_empty() {
while running.len() < threads && !remaining.is_empty() {
let case = remaining.pop_front().unwrap();
let case = std::sync::Arc::new(case);
let name = case.name().to_owned();

let cfg = std::thread::Builder::new().name(name.clone());
let start = *start;
let tx = tx.clone();
let case = std::sync::Arc::new(case);
let case_fallback = case.clone();
let context = context.clone();
let context_fallback = context.clone();
let sync_success = sync_success.clone();
let sync_success_fallback = sync_success.clone();
let thread_tx = tx.clone();
let thread_case = case.clone();
let mut thread_context = context.clone();
thread_context.test_name = name.clone();
let thread_sync_success = sync_success.clone();
let join_handle = cfg.spawn(move || {
let mut notifier = SenderNotifier { tx: tx.clone() };
let case_success =
run_case(&start, case.as_ref().as_ref(), &context, &mut notifier)
.expect("`SenderNotifier` is infallible");
if !case_success {
sync_success.store(case_success, std::sync::atomic::Ordering::Relaxed);
let status = run_case(thread_case.as_ref().as_ref(), &thread_context);
if !matches!(status, Ok(true)) {
thread_sync_success.store(false, std::sync::atomic::Ordering::Relaxed);
}
let _ = thread_tx.send(thread_case.name().to_owned());
status
});
match join_handle {
Ok(join_handle) => {
running_tests.insert(name.clone(), RunningTest { join_handle });
pending += 1;
running.insert(name.clone(), join_handle);
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
// `ErrorKind::WouldBlock` means hitting the thread limit on some
// platforms, so run the test synchronously here instead.
let case_success = run_case(
&start,
case_fallback.as_ref().as_ref(),
&context_fallback,
notifier,
)
.expect("`SenderNotifier` is infallible");
let case_success = run_case(case.as_ref().as_ref(), &context)?;
if !case_success {
sync_success_fallback
.store(case_success, std::sync::atomic::Ordering::Relaxed);
sync_success.store(case_success, std::sync::atomic::Ordering::Relaxed);
}
}
Err(e) => {
Expand All @@ -435,13 +399,9 @@ fn run(
}
}

let event = rx.recv().unwrap();
if let notify::Event::CaseComplete(event) = &event {
let running_test = running_tests.remove(&event.name).unwrap();
running_test.join(start, event, notifier)?;
pending -= 1;
}
notifier.notify(event)?;
let test_name = rx.recv().unwrap();
let running_test = running.remove(&test_name).unwrap();
let _ = running_test.join();
success &= sync_success.load(std::sync::atomic::Ordering::SeqCst);
if !success && opts.fail_fast {
break;
Expand All @@ -450,16 +410,16 @@ fn run(
}

if !exclusive_cases.is_empty() {
notifier.threaded(false);
context.notifier().threaded(false);
for case in exclusive_cases {
success &= run_case(start, case.as_ref(), &context, notifier)?;
success &= run_case(case.as_ref(), &context)?;
if !success && opts.fail_fast {
break;
}
}
}

notifier.notify(
context.notifier().notify(
notify::event::RunComplete {
elapsed_s: Some(notify::Elapsed(start.elapsed())),
}
Expand All @@ -469,16 +429,11 @@ fn run(
Ok(success)
}

fn run_case(
start: &std::time::Instant,
case: &dyn Case,
context: &TestContext,
notifier: &mut dyn notify::Notifier,
) -> std::io::Result<bool> {
notifier.notify(
fn run_case(case: &dyn Case, context: &TestContext) -> std::io::Result<bool> {
context.notifier().notify(
notify::event::CaseStart {
name: case.name().to_owned(),
elapsed_s: Some(notify::Elapsed(start.elapsed())),
elapsed_s: Some(context.elapased_s()),
}
.into(),
)?;
Expand Down Expand Up @@ -507,21 +462,21 @@ fn run_case(
let kind = err.status();
case_status = Some(kind);
let message = err.cause().map(|c| c.to_string());
notifier.notify(
context.notifier().notify(
notify::event::CaseMessage {
name: case.name().to_owned(),
kind,
message,
elapsed_s: Some(notify::Elapsed(start.elapsed())),
elapsed_s: Some(context.elapased_s()),
}
.into(),
)?;
}

notifier.notify(
context.notifier().notify(
notify::event::CaseComplete {
name: case.name().to_owned(),
elapsed_s: Some(notify::Elapsed(start.elapsed())),
elapsed_s: Some(context.elapased_s()),
}
.into(),
)?;
Expand All @@ -537,16 +492,3 @@ fn __rust_begin_short_backtrace<T, F: FnOnce() -> T>(f: F) -> T {
// prevent this frame from being tail-call optimised away
std::hint::black_box(result)
}

#[derive(Clone, Debug)]
struct SenderNotifier {
tx: std::sync::mpsc::Sender<notify::Event>,
}

impl notify::Notifier for SenderNotifier {
fn notify(&mut self, event: notify::Event) -> std::io::Result<()> {
// If the sender doesn't care, neither do we
let _ = self.tx.send(event);
Ok(())
}
}
Loading
Loading