Commit 3f42a03
Revert "introduce callback to handle link expiry"
This reverts commit bd51b1c.
1 parent 0868fe3 commit 3f42a03

6 files changed: +32 -335 lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 22 additions & 229 deletions
@@ -22,7 +22,7 @@
     ResultManifest,
 )
 from databricks.sql.backend.sea.utils.constants import ResultFormat
-from databricks.sql.exc import ProgrammingError, Error
+from databricks.sql.exc import ProgrammingError
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 from databricks.sql.types import SSLOptions
 from databricks.sql.utils import (
@@ -137,73 +137,10 @@ def __init__(
         self._error: Optional[Exception] = None
         self.chunk_index_to_link: Dict[int, "ExternalLink"] = {}
 
-        # Add initial links (no notification needed during init)
-        self._add_links_to_manager(initial_links, notify=False)
-        self.total_chunk_count = total_chunk_count
-        self._worker_thread: Optional[threading.Thread] = None
-
-    def _add_links_to_manager(self, links: List["ExternalLink"], notify: bool = True):
-        """
-        Add external links to both chunk mapping and download manager.
-
-        Args:
-            links: List of external links to add
-            notify: Whether to notify waiting threads (default True)
-        """
-        for link in links:
+        for link in initial_links:
             self.chunk_index_to_link[link.chunk_index] = link
             self.download_manager.add_link(self._convert_to_thrift_link(link))
-
-        if notify:
-            self._link_data_update.notify_all()
-
-    def _clear_chunks_from_index(self, start_chunk_index: int):
-        """
-        Clear all chunks >= start_chunk_index from the chunk mapping.
-
-        Args:
-            start_chunk_index: The chunk index to start clearing from (inclusive)
-        """
-        chunks_to_remove = [
-            chunk_idx
-            for chunk_idx in self.chunk_index_to_link.keys()
-            if chunk_idx >= start_chunk_index
-        ]
-
-        logger.debug(
-            f"LinkFetcher: Clearing chunks {chunks_to_remove} from index {start_chunk_index}"
-        )
-        for chunk_idx in chunks_to_remove:
-            del self.chunk_index_to_link[chunk_idx]
-
-    def _fetch_and_add_links(self, chunk_index: int) -> List["ExternalLink"]:
-        """
-        Fetch links from backend and add them to manager.
-
-        Args:
-            chunk_index: The chunk index to fetch
-
-        Returns:
-            List of fetched external links
-
-        Raises:
-            Exception: If fetching fails
-        """
-        logger.debug(f"LinkFetcher: Fetching links for chunk {chunk_index}")
-
-        try:
-            links = self.backend.get_chunk_links(self._statement_id, chunk_index)
-            self._add_links_to_manager(links, notify=True)
-            logger.debug(
-                f"LinkFetcher: Added {len(links)} links starting from chunk {chunk_index}"
-            )
-            return links
-
-        except Exception as e:
-            logger.error(f"LinkFetcher: Failed to fetch chunk {chunk_index}: {e}")
-            self._error = e
-            self._link_data_update.notify_all()
-            raise e
+        self.total_chunk_count = total_chunk_count
 
     def _get_next_chunk_index(self) -> Optional[int]:
         with self._link_data_update:
@@ -218,13 +155,23 @@ def _trigger_next_batch_download(self) -> bool:
         if next_chunk_index is None:
             return False
 
-        with self._link_data_update:
-            try:
-                self._fetch_and_add_links(next_chunk_index)
-                return True
-            except Exception:
-                # Error already logged and set by _fetch_and_add_links
-                return False
+        try:
+            links = self.backend.get_chunk_links(self._statement_id, next_chunk_index)
+            with self._link_data_update:
+                for l in links:
+                    self.chunk_index_to_link[l.chunk_index] = l
+                    self.download_manager.add_link(self._convert_to_thrift_link(l))
+                self._link_data_update.notify_all()
+        except Exception as e:
+            logger.error(
+                f"LinkFetcher: Error fetching links for chunk {next_chunk_index}: {e}"
+            )
+            with self._link_data_update:
+                self._error = e
+                self._link_data_update.notify_all()
+            return False
+
+        return True
 
     def get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
         if chunk_index >= self.total_chunk_count:
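
Note: the post-revert _trigger_next_batch_download above is the producer half of a condition-variable handshake. The backend call runs outside the lock; the shared chunk map (or the error field) is then updated under _link_data_update and notify_all() wakes any reader blocked on that condition, so waiters are only held up for the dictionary update, never for the network round trip. A minimal standalone sketch of the pattern (toy names, not the driver's classes):

    import threading
    from typing import Dict, Optional

    class LinkStore:
        """Toy producer/consumer handshake mirroring the Condition usage above."""

        def __init__(self) -> None:
            self._cond = threading.Condition()  # plays the role of _link_data_update
            self._links: Dict[int, str] = {}    # chunk_index -> link (stand-in type)
            self._error: Optional[Exception] = None

        def publish(self, chunk_index: int, link: str) -> None:
            # Producer: mutate shared state under the lock, then wake all waiters.
            with self._cond:
                self._links[chunk_index] = link
                self._cond.notify_all()

        def fail(self, error: Exception) -> None:
            # Producer error path: record the failure and wake waiters too.
            with self._cond:
                self._error = error
                self._cond.notify_all()

        def get(self, chunk_index: int) -> str:
            # Consumer: block until the link appears or an error is recorded.
            with self._cond:
                while chunk_index not in self._links and self._error is None:
                    self._cond.wait()
                if self._error is not None:
                    raise self._error
                return self._links[chunk_index]
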
@@ -238,45 +185,6 @@ def get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
 
         return self.chunk_index_to_link.get(chunk_index, None)
 
-    def restart_from_chunk(self, chunk_index: int):
-        """
-        Restart the LinkFetcher from a specific chunk index.
-
-        This method handles both cases:
-        1. LinkFetcher is done/closed but we need to restart it
-        2. LinkFetcher is active but we need it to start from the expired chunk
-
-        The key insight: we need to clear all chunks >= restart_chunk_index
-        so that _get_next_chunk_index() returns the correct next chunk.
-
-        Args:
-            chunk_index: The chunk index to restart from
-        """
-        logger.debug(f"LinkFetcher: Restarting from chunk {chunk_index}")
-
-        # Stop the current worker if running
-        self.stop()
-
-        with self._link_data_update:
-            # Clear error state
-            self._error = None
-
-            # 🔥 CRITICAL: Clear all chunks >= restart_chunk_index
-            # This ensures _get_next_chunk_index() works correctly
-            self._clear_chunks_from_index(chunk_index)
-
-            # Now fetch the restart chunk (and potentially its batch)
-            # This becomes our new "max chunk" and starting point
-            try:
-                self._fetch_and_add_links(chunk_index)
-            except Exception as e:
-                # Error already logged and set by _fetch_and_add_links
-                raise e
-
-        # Start the worker again - now _get_next_chunk_index() will work correctly
-        self.start()
-        logger.debug(f"LinkFetcher: Successfully restarted from chunk {chunk_index}")
-
     def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink:
         """Convert SEA external links to Thrift format for compatibility with existing download manager."""
         # Parse the ISO format expiration time
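
Note: the removed restart_from_chunk depends on _get_next_chunk_index() resuming from whatever remains in chunk_index_to_link after clearing, which is why every chunk >= the restart index must be dropped first. The body of _get_next_chunk_index is not shown in this diff; a plausible standalone sketch consistent with that reasoning (an assumption, not the actual implementation):

    from typing import Dict, Optional

    def next_chunk_index(
        chunk_index_to_link: Dict[int, object], total_chunk_count: int
    ) -> Optional[int]:
        # Sketch: the next chunk to fetch is one past the largest chunk index
        # already known; None once every chunk has been fetched.
        if not chunk_index_to_link:
            return 0
        candidate = max(chunk_index_to_link) + 1
        return candidate if candidate < total_chunk_count else None

    # After restart_from_chunk(5) clears chunks >= 5 and refetches chunk 5,
    # the worker would resume at chunk 6:
    assert next_chunk_index({0: "a", 5: "f"}, 10) == 6
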
@@ -297,17 +205,12 @@ def _worker_loop(self):
             break
 
     def start(self):
-        if self._worker_thread and self._worker_thread.is_alive():
-            return  # Already running
-
-        self._shutdown_event.clear()
         self._worker_thread = threading.Thread(target=self._worker_loop)
         self._worker_thread.start()
 
     def stop(self):
-        if self._worker_thread and self._worker_thread.is_alive():
-            self._shutdown_event.set()
-            self._worker_thread.join()
+        self._shutdown_event.set()
+        self._worker_thread.join()
 
 
 class SeaCloudFetchQueue(CloudFetchQueue):
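
Note: the guards stripped from start()/stop() were what made the fetcher restartable. A threading.Event stays set once set, so without the removed _shutdown_event.clear() a worker started after stop() observes the shutdown flag and exits immediately. A standalone sketch of that behavior (toy worker, not the driver's class):

    import threading

    shutdown = threading.Event()

    def worker() -> None:
        # Loop until the shutdown flag is set.
        while not shutdown.is_set():
            shutdown.wait(timeout=0.1)  # stand-in for real work

    t = threading.Thread(target=worker)
    t.start()
    shutdown.set()  # stop()
    t.join()

    # Restarting without clearing: the new worker exits at once,
    # because the Event is still set.
    t2 = threading.Thread(target=worker)
    t2.start()
    t2.join()  # returns immediately

    shutdown.clear()  # the removed step that made a true restart possible
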
@@ -366,7 +269,6 @@ def __init__(
             max_download_threads=max_download_threads,
             lz4_compressed=lz4_compressed,
             ssl_options=ssl_options,
-            expired_link_callback=self._handle_expired_link,
         )
 
         self.link_fetcher = LinkFetcher(
@@ -381,115 +283,6 @@
         # Initialize table and position
         self.table = self._create_next_table()
 
-    def _handle_expired_link(
-        self, expired_link: TSparkArrowResultLink
-    ) -> TSparkArrowResultLink:
-        """
-        Handle expired link for SEA backend.
-
-        For SEA backend, we can handle expired links robustly by:
-        1. Cancelling all pending downloads
-        2. Finding the chunk index for the expired link
-        3. Restarting the LinkFetcher from that chunk
-        4. Returning the requested link
-
-        Args:
-            expired_link: The expired link
-
-        Returns:
-            A new link with the same row offset
-
-        Raises:
-            Error: If unable to fetch new link
-        """
-        logger.warning(
-            "SeaCloudFetchQueue: Link expired for offset {}, row count {}. Attempting to fetch new links.".format(
-                expired_link.startRowOffset, expired_link.rowCount
-            )
-        )
-
-        if not self.download_manager:
-            raise ValueError("Download manager not initialized")
-
-        try:
-            # Step 1: Cancel all pending downloads
-            self.download_manager.cancel_all_downloads()
-            logger.debug("SeaCloudFetchQueue: Cancelled all pending downloads")
-
-            # Step 2: Find which chunk contains the expired link
-            target_chunk_index = self._find_chunk_index_for_row_offset(
-                expired_link.startRowOffset
-            )
-            if target_chunk_index is None:
-                # If we can't find the chunk, we may need to search more broadly
-                # For now, let's assume it's a reasonable chunk based on the row offset
-                # This is a fallback - in practice this should be rare
-                logger.warning(
-                    "SeaCloudFetchQueue: Could not find chunk index for row offset {}, using fallback approach".format(
-                        expired_link.startRowOffset
-                    )
-                )
-                # Try to estimate chunk index - this is a heuristic
-                target_chunk_index = 0  # Start from beginning as fallback
-
-            # Step 3: Restart LinkFetcher from the target chunk
-            # This handles both stopped and active LinkFetcher cases
-            self.link_fetcher.restart_from_chunk(target_chunk_index)
-
-            # Step 4: Find and return the link that matches the expired link's row offset
-            # After restart, the chunk should be available
-            for (
-                chunk_index,
-                external_link,
-            ) in self.link_fetcher.chunk_index_to_link.items():
-                if external_link.row_offset == expired_link.startRowOffset:
-                    new_thrift_link = self.link_fetcher._convert_to_thrift_link(
-                        external_link
-                    )
-                    logger.debug(
-                        "SeaCloudFetchQueue: Found replacement link for offset {}, row count {}".format(
-                            new_thrift_link.startRowOffset, new_thrift_link.rowCount
-                        )
-                    )
-                    return new_thrift_link
-
-            # If we still can't find it, raise an error
-            logger.error(
-                "SeaCloudFetchQueue: Could not find replacement link for row offset {} after restart".format(
-                    expired_link.startRowOffset
-                )
-            )
-            raise Error(
-                f"CloudFetch link has expired and could not be renewed for offset {expired_link.startRowOffset}"
-            )
-
-        except Exception as e:
-            logger.error(
-                "SeaCloudFetchQueue: Error handling expired link: {}".format(str(e))
-            )
-            if isinstance(e, Error):
-                raise e
-            else:
-                raise Error(f"CloudFetch link has expired and renewal failed: {str(e)}")
-
-    def _find_chunk_index_for_row_offset(self, row_offset: int) -> Optional[int]:
-        """
-        Find the chunk index that contains the given row offset.
-
-        Args:
-            row_offset: The row offset to find
-
-        Returns:
-            The chunk index, or None if not found
-        """
-        # Search through our known chunks to find the one containing this row offset
-        for chunk_index, external_link in self.link_fetcher.chunk_index_to_link.items():
-            if external_link.row_offset == row_offset:
-                return chunk_index
-
-        # If not found in known chunks, return None and let the caller handle it
-        return None
-
     def _create_next_table(self) -> Union["pyarrow.Table", None]:
         """Create next table by retrieving the logical next downloaded file."""
         if not self.download_manager:

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 2 additions & 29 deletions
@@ -1,7 +1,7 @@
 import logging
 
 from concurrent.futures import ThreadPoolExecutor, Future
-from typing import List, Union, Callable
+from typing import List, Union
 
 from databricks.sql.cloudfetch.downloader import (
     ResultSetDownloadHandler,
@@ -22,7 +22,6 @@ def __init__(
         max_download_threads: int,
         lz4_compressed: bool,
         ssl_options: SSLOptions,
-        expired_link_callback: Callable[[TSparkArrowResultLink], TSparkArrowResultLink],
     ):
         self._pending_links: List[TSparkArrowResultLink] = []
         for link in links:
@@ -39,10 +38,7 @@ def __init__(
         self._max_download_threads: int = max_download_threads
         self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
 
-        self._downloadable_result_settings = DownloadableResultSettings(
-            is_lz4_compressed=lz4_compressed,
-            expired_link_callback=expired_link_callback,
-        )
+        self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
         self._ssl_options = ssl_options
 
     def get_next_downloaded_file(
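
Note: the constructor calls in this hunk imply DownloadableResultSettings briefly carried an expired_link_callback field alongside is_lz4_compressed, and the revert returns it to the single-argument positional form. A hedged sketch of the two shapes as dataclasses (any field beyond those visible here is an assumption):

    from dataclasses import dataclass
    from typing import Callable

    class TSparkArrowResultLink:  # stand-in for the real Thrift type
        pass

    @dataclass
    class DownloadableResultSettingsPreRevert:
        # Shape implied by the removed keyword-argument call.
        is_lz4_compressed: bool
        expired_link_callback: Callable[
            [TSparkArrowResultLink], TSparkArrowResultLink
        ]

    @dataclass
    class DownloadableResultSettingsPostRevert:
        # Shape implied by DownloadableResultSettings(lz4_compressed).
        is_lz4_compressed: bool
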
@@ -123,29 +119,6 @@ def add_link(self, link: TSparkArrowResultLink):
         )
         self._pending_links.append(link)
 
-    def cancel_all_downloads(self):
-        """
-        Cancel all pending downloads and clear the download queue.
-
-        This method is typically called when links have expired and we need to
-        cancel all pending downloads before fetching new links.
-        """
-        logger.debug("ResultFileDownloadManager: cancelling all downloads")
-
-        # Cancel all pending download tasks
-        cancelled_count = 0
-        for task in self._download_tasks:
-            if task.cancel():
-                cancelled_count += 1
-
-        logger.debug(
-            f"ResultFileDownloadManager: cancelled {cancelled_count} out of {len(self._download_tasks)} downloads"
-        )
-
-        # Clear the download tasks and pending links
-        self._download_tasks.clear()
-        self._pending_links.clear()
-
     def _shutdown_manager(self):
         # Clear download handlers and shutdown the thread pool
         self._pending_links = []
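
Note on the removed cancel_all_downloads: concurrent.futures.Future.cancel() only succeeds for tasks that have not yet started running, which is why the method counted how many cancellations actually took effect rather than assuming all of them did. A small standalone demonstration:

    import time
    from concurrent.futures import ThreadPoolExecutor

    def slow(n: int) -> int:
        time.sleep(0.5)
        return n

    with ThreadPoolExecutor(max_workers=1) as pool:
        tasks = [pool.submit(slow, i) for i in range(4)]
        time.sleep(0.1)  # let the first task start running

        # cancel() returns True only for tasks still waiting in the queue;
        # the task already running cannot be cancelled.
        cancelled = sum(1 for t in tasks if t.cancel())
        print(f"cancelled {cancelled} of {len(tasks)}")  # typically: cancelled 3 of 4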
