Skip to content

Commit dafa3a5

Browse files
committed
test: move test_collect_interrupted to test_dataframe.py
1 parent 3f17c9c commit dafa3a5

File tree

2 files changed

+107
-123
lines changed

2 files changed

+107
-123
lines changed

python/tests/test_dataframe.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
import datetime
1818
import os
1919
import re
20+
import threading
21+
import time
22+
import ctypes
2023
from typing import Any
2124

2225
import pyarrow as pa
@@ -1914,7 +1917,7 @@ def test_fill_null_date32_column(null_df):
19141917
dates = result.column(4).to_pylist()
19151918
assert dates[0] == datetime.date(2000, 1, 1) # Original value
19161919
assert dates[1] == epoch_date # Filled value
1917-
assert dates[2] == datetime.date(2022, 1, 1) # Original value
1920+
assert dates[2] == datetime.date(2022, 1, 1) # Original value
19181921
assert dates[3] == epoch_date # Filled value
19191922

19201923
# Other date column should be unchanged
@@ -2061,3 +2064,106 @@ def test_fill_null_all_null_column(ctx):
20612064
# Check that all nulls were filled
20622065
result = filled_df.collect()[0]
20632066
assert result.column(1).to_pylist() == ["filled", "filled", "filled"]
2067+
2068+
2069+
def test_collect_interrupted():
2070+
"""Test that a long-running query can be interrupted with Ctrl-C.
2071+
2072+
This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
2073+
exception in the main thread during a long-running query execution.
2074+
"""
2075+
# Create a context and a DataFrame with a query that will run for a while
2076+
ctx = SessionContext()
2077+
2078+
# Create a recursive computation that will run for some time
2079+
batches = []
2080+
for i in range(10):
2081+
batch = pa.RecordBatch.from_arrays(
2082+
[
2083+
pa.array(list(range(i * 1000, (i + 1) * 1000))),
2084+
pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
2085+
],
2086+
names=["a", "b"],
2087+
)
2088+
batches.append(batch)
2089+
2090+
# Register tables
2091+
ctx.register_record_batches("t1", [batches])
2092+
ctx.register_record_batches("t2", [batches])
2093+
2094+
# Create a large join operation that will take time to process
2095+
df = ctx.sql("""
2096+
WITH t1_expanded AS (
2097+
SELECT
2098+
a,
2099+
b,
2100+
CAST(a AS DOUBLE) / 1.5 AS c,
2101+
CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
2102+
FROM t1
2103+
CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
2104+
),
2105+
t2_expanded AS (
2106+
SELECT
2107+
a,
2108+
b,
2109+
CAST(a AS DOUBLE) * 2.5 AS e,
2110+
CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
2111+
FROM t2
2112+
CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
2113+
)
2114+
SELECT
2115+
t1.a, t1.b, t1.c, t1.d,
2116+
t2.a AS a2, t2.b AS b2, t2.e, t2.f
2117+
FROM t1_expanded t1
2118+
JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
2119+
WHERE t1.a > 100 AND t2.a > 100
2120+
""")
2121+
2122+
# Flag to track if the query was interrupted
2123+
interrupted = False
2124+
interrupt_error = None
2125+
main_thread = threading.main_thread()
2126+
2127+
# This function will be run in a separate thread and will raise
2128+
# KeyboardInterrupt in the main thread
2129+
def trigger_interrupt():
2130+
"""Wait a moment then raise KeyboardInterrupt in the main thread"""
2131+
time.sleep(0.5) # Give the query time to start
2132+
2133+
# Check if thread ID is available
2134+
thread_id = main_thread.ident
2135+
if thread_id is None:
2136+
msg = "Cannot get main thread ID"
2137+
raise RuntimeError(msg)
2138+
2139+
# Use ctypes to raise exception in main thread
2140+
exception = ctypes.py_object(KeyboardInterrupt)
2141+
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
2142+
ctypes.c_long(thread_id), exception)
2143+
if res != 1:
2144+
# If res is 0, the thread ID was invalid
2145+
# If res > 1, we modified multiple threads
2146+
ctypes.pythonapi.PyThreadState_SetAsyncExc(
2147+
ctypes.c_long(thread_id), ctypes.py_object(0))
2148+
msg = "Failed to raise KeyboardInterrupt in main thread"
2149+
raise RuntimeError(msg)
2150+
2151+
# Start a thread to trigger the interrupt
2152+
interrupt_thread = threading.Thread(target=trigger_interrupt)
2153+
interrupt_thread.daemon = True
2154+
interrupt_thread.start()
2155+
2156+
# Execute the query and expect it to be interrupted
2157+
try:
2158+
df.collect()
2159+
except KeyboardInterrupt:
2160+
interrupted = True
2161+
except Exception as e:
2162+
interrupt_error = e
2163+
2164+
# Assert that the query was interrupted properly
2165+
assert interrupted, "Query was not interrupted by KeyboardInterrupt"
2166+
assert interrupt_error is None, f"Unexpected error occurred: {interrupt_error}"
2167+
2168+
# Make sure the interrupt thread has finished
2169+
interrupt_thread.join(timeout=1.0)

python/tests/test_interruption.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

0 commit comments

Comments
 (0)