Skip to content

Commit 15307eb

Browse files
committed
Fix signed overflow with midpoint interpolation
The simple algorithm is replaced with an Euclid division and remainder based slightly more complex one when the inputs are integers. Moreover, the implementation is separated for integers and floating point numbers, relying on macros instead of generics with trait bounds. A regression test that failed before the changes is added along with property-based testing ensuring that the results match those that the previous version output (when within the non-overflowing limits).
1 parent 2b07b9e commit 15307eb

File tree

3 files changed

+130
-18
lines changed

3 files changed

+130
-18
lines changed

src/maybe_nan/impl_not_none.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::NotNone;
2-
use num_traits::{FromPrimitive, ToPrimitive};
2+
use num_traits::{Euclid, FromPrimitive, ToPrimitive};
33
use std::cmp;
44
use std::fmt;
55
use std::ops::{Add, Deref, DerefMut, Div, Mul, Rem, Sub};
@@ -101,6 +101,19 @@ impl<T: Rem> Rem for NotNone<T> {
101101
}
102102
}
103103

104+
impl<T: Euclid> Euclid for NotNone<T> {
105+
#[inline]
106+
fn div_euclid(&self, rhs: &Self) -> Self {
107+
let result = self.deref().div_euclid(rhs.deref());
108+
NotNone(Some(result))
109+
}
110+
#[inline]
111+
fn rem_euclid(&self, rhs: &Self) -> Self {
112+
let result = self.deref().rem_euclid(rhs.deref());
113+
NotNone(Some(result))
114+
}
115+
}
116+
104117
impl<T: ToPrimitive> ToPrimitive for NotNone<T> {
105118
#[inline]
106119
fn to_isize(&self) -> Option<isize> {

src/quantile/interpolate.rs

Lines changed: 105 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
//! Interpolation strategies.
22
use noisy_float::types::N64;
3-
use num_traits::{Float, FromPrimitive, NumOps, ToPrimitive};
3+
use num_traits::{Euclid, Float, FromPrimitive, NumOps, ToPrimitive};
4+
5+
use crate::maybe_nan::NotNone;
46

57
fn float_quantile_index(q: N64, len: usize) -> N64 {
68
q * ((len - 1) as f64)
@@ -104,25 +106,69 @@ impl<T> Interpolate<T> for Nearest {
104106
private_impl! {}
105107
}
106108

107-
impl<T> Interpolate<T> for Midpoint
108-
where
109-
T: NumOps + Clone + FromPrimitive,
110-
{
111-
fn needs_lower(_q: N64, _len: usize) -> bool {
112-
true
113-
}
114-
fn needs_higher(_q: N64, _len: usize) -> bool {
115-
true
109+
macro_rules! impl_midpoint_interpolate_for_float {
110+
($($t:ty),*) => {
111+
$(
112+
impl Interpolate<$t> for Midpoint {
113+
fn needs_lower(_q: N64, _len: usize) -> bool {
114+
true
115+
}
116+
fn needs_higher(_q: N64, _len: usize) -> bool {
117+
true
118+
}
119+
fn interpolate(lower: Option<$t>, higher: Option<$t>, _q: N64, _len: usize) -> $t {
120+
let lower = lower.unwrap();
121+
let higher = higher.unwrap();
122+
lower + (higher - lower) / 2.0
123+
}
124+
private_impl! {}
125+
}
126+
)*
116127
}
117-
fn interpolate(lower: Option<T>, higher: Option<T>, _q: N64, _len: usize) -> T {
118-
let denom = T::from_u8(2).unwrap();
119-
let lower = lower.unwrap();
120-
let higher = higher.unwrap();
121-
lower.clone() + (higher.clone() - lower.clone()) / denom.clone()
128+
}
129+
130+
impl_midpoint_interpolate_for_float!(f32, f64);
131+
132+
macro_rules! impl_midpoint_interpolate_for_integer {
133+
($($t:ty),*) => {
134+
$(
135+
impl Interpolate<$t> for Midpoint {
136+
fn needs_lower(_q: N64, _len: usize) -> bool {
137+
true
138+
}
139+
fn needs_higher(_q: N64, _len: usize) -> bool {
140+
true
141+
}
142+
fn interpolate(lower: Option<$t>, higher: Option<$t>, _q: N64, _len: usize) -> $t {
143+
let two = <$t>::from_u8(2).unwrap();
144+
let (lower_half, lower_rem) = lower.unwrap().div_rem_euclid(&two);
145+
let (higher_half, higher_rem) = higher.unwrap().div_rem_euclid(&two);
146+
lower_half + higher_half + (lower_rem * higher_rem)
147+
}
148+
private_impl! {}
149+
}
150+
)*
122151
}
123-
private_impl! {}
124152
}
125153

154+
impl_midpoint_interpolate_for_integer!(
155+
i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize
156+
);
157+
impl_midpoint_interpolate_for_integer!(
158+
NotNone<i8>,
159+
NotNone<i16>,
160+
NotNone<i32>,
161+
NotNone<i64>,
162+
NotNone<i128>,
163+
NotNone<isize>,
164+
NotNone<u8>,
165+
NotNone<u16>,
166+
NotNone<u32>,
167+
NotNone<u64>,
168+
NotNone<u128>,
169+
NotNone<usize>
170+
);
171+
126172
impl<T> Interpolate<T> for Linear
127173
where
128174
T: NumOps + Clone + FromPrimitive + ToPrimitive,
@@ -143,3 +189,46 @@ where
143189
}
144190
private_impl! {}
145191
}
192+
193+
#[cfg(test)]
194+
mod tests {
195+
use super::*;
196+
use noisy_float::types::n64;
197+
use quickcheck::TestResult;
198+
use quickcheck_macros::quickcheck;
199+
200+
#[derive(Clone, Copy, Debug)]
201+
struct LowerHigherPair<T>(T, T);
202+
203+
impl quickcheck::Arbitrary for LowerHigherPair<i64> {
204+
fn arbitrary<G: quickcheck::Gen>(g: &mut G) -> Self {
205+
let (l, h) = loop {
206+
let (l, h) = (i64::arbitrary(g), i64::arbitrary(g));
207+
if l > h || h.checked_sub(l).is_none() {
208+
continue;
209+
}
210+
break (l, h);
211+
};
212+
LowerHigherPair(l, h)
213+
}
214+
}
215+
216+
impl From<LowerHigherPair<i64>> for (i64, i64) {
217+
fn from(value: LowerHigherPair<i64>) -> Self {
218+
(value.0, value.1)
219+
}
220+
}
221+
222+
fn naive_midpoint_i64(lower: i64, higher: i64) -> i64 {
223+
// Overflows when higher is very big and lower is very small
224+
lower + (higher - lower) / 2
225+
}
226+
227+
#[quickcheck]
228+
fn test_midpoint_algo_eq_naive_algo_i64(lh: LowerHigherPair<i64>) -> TestResult {
229+
let (lower, higher) = lh.into();
230+
let naive = naive_midpoint_i64(lower, higher);
231+
let midpoint = Midpoint::interpolate(Some(lower), Some(higher), n64(0.0), 0);
232+
TestResult::from_bool(naive == midpoint)
233+
}
234+
}

tests/quantile.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ fn test_quantile_axis_skipnan_mut_linear_opt_i32() {
268268
}
269269

270270
#[test]
271-
fn test_midpoint_overflow() {
271+
fn test_midpoint_overflow_unsigned() {
272272
// Regression test
273273
// This triggered an overflow panic with a naive Midpoint implementation: (a+b)/2
274274
let mut a: Array1<u8> = array![129, 130, 130, 131];
@@ -277,6 +277,16 @@ fn test_midpoint_overflow() {
277277
assert_eq!(median, expected_median);
278278
}
279279

280+
#[test]
281+
fn test_midpoint_overflow_signed() {
282+
// Regression test
283+
// This triggered an overflow panic with a naive Midpoint implementation: b+(a-b)/2
284+
let mut a: Array1<i64> = array![i64::MIN, i64::MAX];
285+
let median = a.quantile_mut(n64(0.5), &Midpoint).unwrap();
286+
let expected_median = -1;
287+
assert_eq!(median, expected_median);
288+
}
289+
280290
#[quickcheck]
281291
fn test_quantiles_mut(xs: Vec<i64>) -> bool {
282292
let v = Array::from(xs.clone());

0 commit comments

Comments
 (0)