diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 6c67fbad34a1..d839e50262d2 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -295,7 +295,7 @@ impl ScalarUDFImpl for DateBinFunc { const NANOS_PER_MICRO: i64 = 1_000; const NANOS_PER_MILLI: i64 = 1_000_000; const NANOS_PER_SEC: i64 = NANOSECONDS; - +type BinFunction = fn(i64, i64, i64) -> Result; enum Interval { Nanoseconds(i64), Months(i64), @@ -310,7 +310,7 @@ impl Interval { /// `source` is the timestamp being binned /// /// `origin` is the time, in nanoseconds, where windows are measured from - fn bin_fn(&self) -> (i64, fn(i64, i64, i64) -> i64) { + fn bin_fn(&self) -> (i64, BinFunction) { match self { Interval::Nanoseconds(nanos) => (*nanos, date_bin_nanos_interval), Interval::Months(months) => (*months, date_bin_months_interval), @@ -319,13 +319,13 @@ impl Interval { } // return time in nanoseconds that the source timestamp falls into based on the stride and origin -fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> i64 { +fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> Result { let time_diff = source - origin; // distance from origin to bin let time_delta = compute_distance(time_diff, stride_nanos); - origin + time_delta + Ok(origin + time_delta) } // distance from origin to bin @@ -341,10 +341,10 @@ fn compute_distance(time_diff: i64, stride: i64) -> i64 { } // return time in nanoseconds that the source timestamp falls into based on the stride and origin -fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> i64 { +fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> Result { // convert source and origin to DateTime - let source_date = to_utc_date_time(source); - let origin_date = to_utc_date_time(origin); + let source_date = to_utc_date_time(source)?; + let origin_date = to_utc_date_time(origin)?; // calculate the number of months between the source and origin let month_diff = (source_date.year() - origin_date.year()) * 12 @@ -355,9 +355,17 @@ fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> i64 let month_delta = compute_distance(month_diff as i64, stride_months); let mut bin_time = if month_delta < 0 { - origin_date - Months::new(month_delta.unsigned_abs() as u32) + match origin_date + .checked_sub_months(Months::new(month_delta.unsigned_abs() as u32)) + { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month subtraction out of range"), + } } else { - origin_date + Months::new(month_delta as u32) + match origin_date.checked_add_months(Months::new(month_delta as u32)) { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month addition out of range"), + } }; // If origin is not midnight of first date of the month, the bin_time may be larger than the source @@ -365,19 +373,32 @@ fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> i64 if bin_time > source_date { let month_delta = month_delta - stride_months; bin_time = if month_delta < 0 { - origin_date - Months::new(month_delta.unsigned_abs() as u32) + match origin_date + .checked_sub_months(Months::new(month_delta.unsigned_abs() as u32)) + { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month subtraction out of range"), + } } else { - origin_date + Months::new(month_delta as u32) + match origin_date.checked_add_months(Months::new(month_delta as u32)) { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month addition out of range"), + } }; } - - bin_time.timestamp_nanos_opt().unwrap() + match bin_time.timestamp_nanos_opt() { + Some(nanos) => Ok(nanos), + None => exec_err!("DATE_BIN result timestamp out of range"), + } } -fn to_utc_date_time(nanos: i64) -> DateTime { +fn to_utc_date_time(nanos: i64) -> Result> { let secs = nanos / NANOS_PER_SEC; let nsec = (nanos % NANOS_PER_SEC) as u32; - DateTime::from_timestamp(secs, nsec).unwrap() + match DateTime::from_timestamp(secs, nsec) { + Some(dt) => Ok(dt), + None => exec_err!("Invalid timestamp value"), + } } // Supported intervals: @@ -546,15 +567,18 @@ fn date_bin_impl( fn stride_map_fn( origin: i64, stride: i64, - stride_fn: fn(i64, i64, i64) -> i64, - ) -> impl Fn(i64) -> i64 { + stride_fn: BinFunction, + ) -> impl Fn(i64) -> Result { let scale = match T::UNIT { Nanosecond => 1, Microsecond => NANOS_PER_MICRO, Millisecond => NANOS_PER_MILLI, Second => NANOSECONDS, }; - move |x: i64| stride_fn(stride, x * scale, origin) / scale + move |x: i64| match stride_fn(stride, x * scale, origin) { + Ok(result) => Ok(result / scale), + Err(e) => Err(e), + } } Ok(match array { @@ -562,7 +586,7 @@ fn date_bin_impl( let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - v.map(apply_stride_fn), + v.and_then(|val| apply_stride_fn(val).ok()), tz_opt.clone(), )) } @@ -570,7 +594,7 @@ fn date_bin_impl( let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - v.map(apply_stride_fn), + v.and_then(|val| apply_stride_fn(val).ok()), tz_opt.clone(), )) } @@ -578,7 +602,7 @@ fn date_bin_impl( let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - v.map(apply_stride_fn), + v.and_then(|val| apply_stride_fn(val).ok()), tz_opt.clone(), )) } @@ -586,7 +610,7 @@ fn date_bin_impl( let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampSecond( - v.map(apply_stride_fn), + v.and_then(|val| apply_stride_fn(val).ok()), tz_opt.clone(), )) } @@ -594,50 +618,61 @@ fn date_bin_impl( if !is_time { return exec_err!("DATE_BIN with Time32 source requires Time32 origin"); } - let apply_stride_fn = move |x: i32| { - let binned_nanos = stride_fn(stride, x as i64 * NANOS_PER_MILLI, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - (nanos / NANOS_PER_MILLI) as i32 - }; - ColumnarValue::Scalar(ScalarValue::Time32Millisecond(v.map(apply_stride_fn))) + let result = v.and_then(|x| { + match stride_fn(stride, x as i64 * NANOS_PER_MILLI, origin) { + Ok(binned_nanos) => { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + Some((nanos / NANOS_PER_MILLI) as i32) + } + Err(_) => None, + } + }); + ColumnarValue::Scalar(ScalarValue::Time32Millisecond(result)) } ColumnarValue::Scalar(ScalarValue::Time32Second(v)) => { if !is_time { return exec_err!("DATE_BIN with Time32 source requires Time32 origin"); } - let apply_stride_fn = move |x: i32| { - let binned_nanos = stride_fn(stride, x as i64 * NANOS_PER_SEC, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - (nanos / NANOS_PER_SEC) as i32 - }; - ColumnarValue::Scalar(ScalarValue::Time32Second(v.map(apply_stride_fn))) + let result = v.and_then(|x| { + match stride_fn(stride, x as i64 * NANOS_PER_SEC, origin) { + Ok(binned_nanos) => { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + Some((nanos / NANOS_PER_SEC) as i32) + } + Err(_) => None, + } + }); + ColumnarValue::Scalar(ScalarValue::Time32Second(result)) } ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(v)) => { if !is_time { return exec_err!("DATE_BIN with Time64 source requires Time64 origin"); } - let apply_stride_fn = move |x: i64| { - let binned_nanos = stride_fn(stride, x, origin); - binned_nanos % (NANOSECONDS_IN_DAY) - }; - ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(v.map(apply_stride_fn))) + let result = v.and_then(|x| match stride_fn(stride, x, origin) { + Ok(binned_nanos) => Some(binned_nanos % (NANOSECONDS_IN_DAY)), + Err(_) => None, + }); + ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(result)) } ColumnarValue::Scalar(ScalarValue::Time64Microsecond(v)) => { if !is_time { return exec_err!("DATE_BIN with Time64 source requires Time64 origin"); } - let apply_stride_fn = move |x: i64| { - let binned_nanos = stride_fn(stride, x * NANOS_PER_MICRO, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - nanos / NANOS_PER_MICRO - }; - ColumnarValue::Scalar(ScalarValue::Time64Microsecond(v.map(apply_stride_fn))) + let result = + v.and_then(|x| match stride_fn(stride, x * NANOS_PER_MICRO, origin) { + Ok(binned_nanos) => { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + Some(nanos / NANOS_PER_MICRO) + } + Err(_) => None, + }); + ColumnarValue::Scalar(ScalarValue::Time64Microsecond(result)) } ColumnarValue::Array(array) => { fn transform_array_with_stride( origin: i64, stride: i64, - stride_fn: fn(i64, i64, i64) -> i64, + stride_fn: BinFunction, array: &ArrayRef, tz_opt: &Option>, ) -> Result @@ -645,11 +680,25 @@ fn date_bin_impl( T: ArrowTimestampType, { let array = as_primitive_array::(array)?; - let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); - let array: PrimitiveArray = array - .unary(apply_stride_fn) - .with_timezone_opt(tz_opt.clone()); + let scale = match T::UNIT { + Nanosecond => 1, + Microsecond => NANOS_PER_MICRO, + Millisecond => NANOS_PER_MILLI, + Second => NANOSECONDS, + }; + + let result: PrimitiveArray = array + .iter() + .map(|opt_val| { + opt_val.and_then(|val| { + stride_fn(stride, val * scale, origin) + .ok() + .map(|v| v / scale) + }) + }) + .collect(); + let array = result.with_timezone_opt(tz_opt.clone()); Ok(ColumnarValue::Array(Arc::new(array))) } @@ -681,15 +730,20 @@ fn date_bin_impl( ); } let array = array.as_primitive::(); - let apply_stride_fn = move |x: i32| { - let binned_nanos = - stride_fn(stride, x as i64 * NANOS_PER_MILLI, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - (nanos / NANOS_PER_MILLI) as i32 - }; - let array: PrimitiveArray = - array.unary(apply_stride_fn); - ColumnarValue::Array(Arc::new(array)) + let result: PrimitiveArray = array + .iter() + .map(|opt_val| { + opt_val.and_then(|x| { + stride_fn(stride, x as i64 * NANOS_PER_MILLI, origin) + .ok() + .map(|binned_nanos| { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + (nanos / NANOS_PER_MILLI) as i32 + }) + }) + }) + .collect(); + ColumnarValue::Array(Arc::new(result)) } Time32(Second) => { if !is_time { @@ -698,15 +752,20 @@ fn date_bin_impl( ); } let array = array.as_primitive::(); - let apply_stride_fn = move |x: i32| { - let binned_nanos = - stride_fn(stride, x as i64 * NANOS_PER_SEC, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - (nanos / NANOS_PER_SEC) as i32 - }; - let array: PrimitiveArray = - array.unary(apply_stride_fn); - ColumnarValue::Array(Arc::new(array)) + let result: PrimitiveArray = array + .iter() + .map(|opt_val| { + opt_val.and_then(|x| { + stride_fn(stride, x as i64 * NANOS_PER_SEC, origin) + .ok() + .map(|binned_nanos| { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + (nanos / NANOS_PER_SEC) as i32 + }) + }) + }) + .collect(); + ColumnarValue::Array(Arc::new(result)) } Time64(Microsecond) => { if !is_time { @@ -715,14 +774,20 @@ fn date_bin_impl( ); } let array = array.as_primitive::(); - let apply_stride_fn = move |x: i64| { - let binned_nanos = stride_fn(stride, x * NANOS_PER_MICRO, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - nanos / NANOS_PER_MICRO - }; - let array: PrimitiveArray = - array.unary(apply_stride_fn); - ColumnarValue::Array(Arc::new(array)) + let result: PrimitiveArray = array + .iter() + .map(|opt_val| { + opt_val.and_then(|x| { + stride_fn(stride, x * NANOS_PER_MICRO, origin).ok().map( + |binned_nanos| { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + nanos / NANOS_PER_MICRO + }, + ) + }) + }) + .collect(); + ColumnarValue::Array(Arc::new(result)) } Time64(Nanosecond) => { if !is_time { @@ -731,13 +796,17 @@ fn date_bin_impl( ); } let array = array.as_primitive::(); - let apply_stride_fn = move |x: i64| { - let binned_nanos = stride_fn(stride, x, origin); - binned_nanos % (NANOSECONDS_IN_DAY) - }; - let array: PrimitiveArray = - array.unary(apply_stride_fn); - ColumnarValue::Array(Arc::new(array)) + let result: PrimitiveArray = array + .iter() + .map(|opt_val| { + opt_val.and_then(|x| { + stride_fn(stride, x, origin).ok().map(|binned_nanos| { + binned_nanos % (NANOSECONDS_IN_DAY) + }) + }) + }) + .collect(); + ColumnarValue::Array(Arc::new(result)) } _ => { return exec_err!( @@ -1193,7 +1262,7 @@ mod tests { let origin1 = string_to_timestamp_nanos(origin).unwrap(); let expected1 = string_to_timestamp_nanos(expected).unwrap(); - let result = date_bin_nanos_interval(stride1, source1, origin1); + let result = date_bin_nanos_interval(stride1, source1, origin1).unwrap(); assert_eq!(result, expected1, "{source} = {expected}"); }) } @@ -1221,8 +1290,55 @@ mod tests { let source1 = string_to_timestamp_nanos(source).unwrap(); let expected1 = string_to_timestamp_nanos(expected).unwrap(); - let result = date_bin_nanos_interval(stride1, source1, 0); + let result = date_bin_nanos_interval(stride1, source1, 0).unwrap(); assert_eq!(result, expected1, "{source} = {expected}"); }) } + + #[test] + fn test_date_bin_out_of_range() { + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + )); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1637426858, 0, 0)), + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(1040292460), + None, + )), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos("1984-01-07 00:00:00").unwrap()), + None, + )), + ]; + + let result = invoke_date_bin_with_args(args, 1, return_field); + assert!(result.is_ok()); + if let ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(val, _)) = + result.unwrap() + { + assert!(val.is_none(), "Expected None for out of range operation"); + } + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1637426858, 0, 0)), + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(-1040292460), + None, + )), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos("1984-01-07 00:00:00").unwrap()), + None, + )), + ]; + + let result = invoke_date_bin_with_args(args, 1, return_field); + assert!(result.is_ok()); + if let ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(val, _)) = + result.unwrap() + { + assert!(val.is_none(), "Expected None for out of range operation"); + } + } }