Skip to content

Commit da2544b

Browse files
committed
Auto merge of #149495 - scottmcm:assume-filter-count, r=Mark-Simulacrum
Assume the returned value in `.filter(…).count()`. Similar to how this helps in `slice::Iter::position`, LLVM sometimes loses track of how high this can get, so for `TrustedLen` iterators tell it what the upper bound is.
2 parents fbab541 + 6bd9d76 commit da2544b

File tree

3 files changed

+107
-2
lines changed

3 files changed

+107
-2
lines changed

library/core/src/iter/adapters/filter.rs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use core::ops::ControlFlow;
44

55
use crate::fmt;
66
use crate::iter::adapters::SourceIter;
7-
use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused};
7+
use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused, TrustedLen};
88
use crate::num::NonZero;
99
use crate::ops::Try;
1010

@@ -138,7 +138,13 @@ where
138138
move |x| predicate(&x) as usize
139139
}
140140

141-
self.iter.map(to_usize(self.predicate)).sum()
141+
let before = self.iter.size_hint().1.unwrap_or(usize::MAX);
142+
let total = self.iter.map(to_usize(self.predicate)).sum();
143+
// SAFETY: `total` and `before` came from the same iterator of type `I`
144+
unsafe {
145+
<I as SpecAssumeCount>::assume_count_le_upper_bound(total, before);
146+
}
147+
total
142148
}
143149

144150
#[inline]
@@ -214,3 +220,34 @@ unsafe impl<I: InPlaceIterable, P> InPlaceIterable for Filter<I, P> {
214220
const EXPAND_BY: Option<NonZero<usize>> = I::EXPAND_BY;
215221
const MERGE_BY: Option<NonZero<usize>> = I::MERGE_BY;
216222
}
223+
224+
trait SpecAssumeCount {
225+
/// # Safety
226+
///
227+
/// `count` must be a number of items actually read from the iterator.
228+
///
229+
/// `upper` must either:
230+
/// - have come from `size_hint().1` on the iterator, or
231+
/// - be `usize::MAX` which will vacuously do nothing.
232+
unsafe fn assume_count_le_upper_bound(count: usize, upper: usize);
233+
}
234+
235+
impl<I: Iterator> SpecAssumeCount for I {
236+
#[inline]
237+
#[rustc_inherit_overflow_checks]
238+
default unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
239+
// In the default we can't trust the `upper` for soundness
240+
// because it came from an untrusted `size_hint`.
241+
242+
// In debug mode we might as well check that the size_hint wasn't too small
243+
let _ = upper - count;
244+
}
245+
}
246+
247+
impl<I: TrustedLen> SpecAssumeCount for I {
248+
#[inline]
249+
unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
250+
// SAFETY: The `upper` is trusted because it came from a `TrustedLen` iterator.
251+
unsafe { crate::hint::assert_unchecked(count <= upper) }
252+
}
253+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//@ compile-flags: -Copt-level=3
2+
//@ edition: 2024
3+
4+
#![crate_type = "lib"]
5+
6+
// Similar to how we `assume` that `slice::Iter::position` is within the length,
7+
// check that `count` also does that for `TrustedLen` iterators.
8+
// See https://rust-lang.zulipchat.com/#narrow/channel/122651-general/topic/Overflow-chk.20removed.20for.20array.20of.2059.2C.20but.20not.2060.2C.20elems/with/561070780
9+
10+
// CHECK-LABEL: @filter_count_untrusted
11+
#[unsafe(no_mangle)]
12+
pub fn filter_count_untrusted(bar: &[u8; 1234]) -> u16 {
13+
// CHECK-NOT: llvm.assume
14+
// CHECK: call void @{{.+}}unwrap_failed
15+
// CHECK-NOT: llvm.assume
16+
let mut iter = bar.iter();
17+
let iter = std::iter::from_fn(|| iter.next()); // Make it not TrustedLen
18+
u16::try_from(iter.filter(|v| **v == 0).count()).unwrap()
19+
}
20+
21+
// CHECK-LABEL: @filter_count_trusted
22+
#[unsafe(no_mangle)]
23+
pub fn filter_count_trusted(bar: &[u8; 1234]) -> u16 {
24+
// CHECK-NOT: unwrap_failed
25+
// CHECK: %[[ASSUME:.+]] = icmp ult {{i64|i32|i16}} %{{.+}}, 1235
26+
// CHECK-NEXT: tail call void @llvm.assume(i1 %[[ASSUME]])
27+
// CHECK-NOT: unwrap_failed
28+
let iter = bar.iter();
29+
u16::try_from(iter.filter(|v| **v == 0).count()).unwrap()
30+
}
31+
32+
// CHECK: ; core::result::unwrap_failed
33+
// CHECK-NEXT: Function Attrs
34+
// CHECK-NEXT: declare{{.+}}void @{{.+}}unwrap_failed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//@ run-pass
2+
//@ needs-unwind
3+
//@ ignore-backends: gcc
4+
//@ compile-flags: -C overflow-checks
5+
6+
use std::panic;
7+
8+
struct Lies(usize);
9+
10+
impl Iterator for Lies {
11+
type Item = usize;
12+
13+
fn next(&mut self) -> Option<usize> {
14+
if self.0 == 0 {
15+
None
16+
} else {
17+
self.0 -= 1;
18+
Some(self.0)
19+
}
20+
}
21+
22+
fn size_hint(&self) -> (usize, Option<usize>) {
23+
(0, Some(2))
24+
}
25+
}
26+
27+
fn main() {
28+
let r = panic::catch_unwind(|| {
29+
// This returns more items than its `size_hint` said was possible,
30+
// which `Filter::count` detects via `overflow-checks`.
31+
let _ = Lies(10).filter(|&x| x > 3).count();
32+
});
33+
assert!(r.is_err());
34+
}

0 commit comments

Comments
 (0)