Skip to content

Commit 3f17c9c

Browse files
committed
test: add interruption handling test for long-running queries in DataFusion
1 parent 97f86dc commit 3f17c9c

File tree

3 files changed

+156
-0
lines changed

3 files changed

+156
-0
lines changed

python/tests/test_dataframe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,7 @@ def test_collect_partitioned():
12961296
assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned()
12971297

12981298

1299+
12991300
def test_union(ctx):
13001301
batch = pa.RecordBatch.from_arrays(
13011302
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],

python/tests/test_interruption.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Tests for handling interruptions in DataFusion operations."""
19+
20+
import threading
21+
import time
22+
import ctypes
23+
24+
import pytest
25+
import pyarrow as pa
26+
from datafusion import SessionContext
27+
28+
29+
def test_collect_interrupted():
    """Test that a long-running query can be interrupted with Ctrl-C.

    This test simulates a Ctrl-C keyboard interrupt by asynchronously raising
    a KeyboardInterrupt in the main thread (via the CPython
    ``PyThreadState_SetAsyncExc`` API) while ``collect()`` is executing, then
    asserts that ``collect()`` surfaced the interrupt rather than swallowing
    it or raising something else.
    """
    # Create a context and a DataFrame with a query that will run for a while
    ctx = SessionContext()

    # Build ten 1000-row batches so the join below has enough data to keep
    # the engine busy long enough for the interrupt to land mid-query.
    batches = []
    for i in range(10):
        batch = pa.RecordBatch.from_arrays(
            [
                pa.array(list(range(i * 1000, (i + 1) * 1000))),
                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
            ],
            names=["a", "b"],
        )
        batches.append(batch)

    # Register the same batches under two names so the query can join them.
    ctx.register_record_batches("t1", [batches])
    ctx.register_record_batches("t2", [batches])

    # Create a large join operation that will take time to process
    df = ctx.sql("""
    WITH t1_expanded AS (
        SELECT
            a,
            b,
            CAST(a AS DOUBLE) / 1.5 AS c,
            CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
        FROM t1
        CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
    ),
    t2_expanded AS (
        SELECT
            a,
            b,
            CAST(a AS DOUBLE) * 2.5 AS e,
            CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
        FROM t2
        CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
    )
    SELECT
        t1.a, t1.b, t1.c, t1.d,
        t2.a AS a2, t2.b AS b2, t2.e, t2.f
    FROM t1_expanded t1
    JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
    WHERE t1.a > 100 AND t2.a > 100
    """)

    # Flags recording how collect() terminated.
    interrupted = False
    interrupt_error = None
    main_thread = threading.main_thread()

    # This function will be run in a separate thread and will raise
    # KeyboardInterrupt in the main thread
    def trigger_interrupt():
        """Wait a moment then raise KeyboardInterrupt in the main thread."""
        time.sleep(0.5)  # Give the query time to start

        # Use ctypes to raise the exception asynchronously in the main thread.
        exception = ctypes.py_object(KeyboardInterrupt)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(main_thread.ident), exception)
        if res != 1:
            # res == 0: the thread id was invalid.
            # res > 1: more than one thread state was modified; the CPython
            # docs say to undo this by calling the API again with exc set to
            # NULL — in ctypes that is `None`, NOT a py_object-wrapped value
            # (py_object(0) would *set* a bogus pending "exception" instead
            # of clearing it).
            ctypes.pythonapi.PyThreadState_SetAsyncExc(
                ctypes.c_long(main_thread.ident), None)
            raise RuntimeError("Failed to raise KeyboardInterrupt in main thread")

    # Start a daemon thread to trigger the interrupt; daemon so a stuck
    # helper cannot keep the interpreter alive after the test finishes.
    interrupt_thread = threading.Thread(target=trigger_interrupt)
    interrupt_thread.daemon = True
    interrupt_thread.start()

    # Execute the query and expect it to be interrupted
    try:
        df.collect()
    except KeyboardInterrupt:
        interrupted = True
    except Exception as e:
        interrupt_error = e

    # Assert that the query was interrupted properly
    assert interrupted, "Query was not interrupted by KeyboardInterrupt"
    assert interrupt_error is None, f"Unexpected error occurred: {interrupt_error}"

    # Make sure the interrupt thread has finished
    interrupt_thread.join(timeout=1.0)

src/context.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,14 @@ impl PySessionContext {
753753
// Extract the single byte to use in the future
754754
let delimiter_byte = delimiter_bytes[0];
755755

756+
// Validate file_compression_type synchronously before async call
757+
if let Some(compression_type) = &file_compression_type {
758+
// Return Python error directly instead of wrapping it in PyDataFusionError to match test expectations
759+
if let Err(err) = parse_file_compression_type(Some(compression_type.clone())) {
760+
return Err(PyDataFusionError::PythonError(err));
761+
}
762+
}
763+
756764
// Clone all string references to create owned values
757765
let file_extension_owned = file_extension.to_string();
758766
let name_owned = name.to_string();
@@ -839,6 +847,14 @@ impl PySessionContext {
839847
.to_str()
840848
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
841849

850+
// Validate file_compression_type synchronously before async call
851+
if let Some(compression_type) = &file_compression_type {
852+
// Return Python error directly instead of wrapping it in PyDataFusionError to match test expectations
853+
if let Err(err) = parse_file_compression_type(Some(compression_type.clone())) {
854+
return Err(PyDataFusionError::PythonError(err));
855+
}
856+
}
857+
842858
// Clone required values to avoid borrowing in the future
843859
let ctx = self.ctx.clone();
844860
let name_owned = name.to_string();
@@ -1040,6 +1056,15 @@ impl PySessionContext {
10401056
let path = path
10411057
.to_str()
10421058
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
1059+
1060+
// Validate file_compression_type synchronously before async call
1061+
if let Some(compression_type) = &file_compression_type {
1062+
// Return Python error directly instead of wrapping it in PyDataFusionError to match test expectations
1063+
if let Err(err) = parse_file_compression_type(Some(compression_type.clone())) {
1064+
return Err(PyDataFusionError::PythonError(err));
1065+
}
1066+
}
1067+
10431068
// Clone required values to avoid borrowing in the future
10441069
let ctx = self.ctx.clone();
10451070
let path_owned = path.to_string();
@@ -1123,6 +1148,14 @@ impl PySessionContext {
11231148
// Store just the delimiter byte to use in the future
11241149
let delimiter_byte = delimiter_bytes[0];
11251150

1151+
// Validate file_compression_type synchronously before async call
1152+
if let Some(compression_type) = &file_compression_type {
1153+
// Return Python error directly instead of wrapping it in PyDataFusionError to match test expectations
1154+
if let Err(err) = parse_file_compression_type(Some(compression_type.clone())) {
1155+
return Err(PyDataFusionError::PythonError(err));
1156+
}
1157+
}
1158+
11261159
// Clone required values to avoid borrowing in the future
11271160
let ctx = self.ctx.clone();
11281161
let file_extension_owned = file_extension.to_string();

0 commit comments

Comments
 (0)