|
3 | 3 | import time |
4 | 4 | from contextlib import contextmanager |
5 | 5 | from dataclasses import dataclass |
| 6 | +from datetime import timedelta |
6 | 7 | from typing import TYPE_CHECKING |
7 | 8 |
|
| 9 | +from async_timeout import Timeout, timeout |
| 10 | + |
8 | 11 | if TYPE_CHECKING: |
9 | 12 | from collections.abc import Iterator |
10 | | - from datetime import timedelta |
| 13 | + from types import TracebackType |
11 | 14 |
|
12 | 15 | _SECONDS_PER_MINUTE = 60 |
13 | 16 | _SECONDS_PER_HOUR = 3600 |
@@ -35,6 +38,43 @@ def measure_time() -> Iterator[TimerResult]: |
35 | 38 | result.cpu = after_cpu - before_cpu |
36 | 39 |
|
37 | 40 |
|
| 41 | +class SharedTimeout: |
| 42 | + """Keeps track of a time budget shared by multiple independent async operations. |
| 43 | +
|
| 44 | + Provides a reusable, non-reentrant context manager interface. |
| 45 | + """ |
| 46 | + |
| 47 | + def __init__(self, timeout: timedelta) -> None: |
| 48 | + self._remaining_timeout = timeout |
| 49 | + self._active_timeout: Timeout | None = None |
| 50 | + self._activation_timestamp: float | None = None |
| 51 | + |
| 52 | + async def __aenter__(self) -> timedelta: |
| 53 | + if self._active_timeout is not None or self._activation_timestamp is not None: |
| 54 | + raise RuntimeError('A shared timeout context cannot be entered twice at the same time') |
| 55 | + |
| 56 | + self._activation_timestamp = time.monotonic() |
| 57 | + self._active_timeout = new_timeout = timeout(self._remaining_timeout.total_seconds()) |
| 58 | + await new_timeout.__aenter__() |
| 59 | + return self._remaining_timeout |
| 60 | + |
| 61 | + async def __aexit__( |
| 62 | + self, |
| 63 | + exc_type: type[BaseException] | None, |
| 64 | + exc_value: BaseException | None, |
| 65 | + exc_traceback: TracebackType | None, |
| 66 | + ) -> None: |
| 67 | + if self._active_timeout is None or self._activation_timestamp is None: |
| 68 | + raise RuntimeError('Logic error') |
| 69 | + |
| 70 | + await self._active_timeout.__aexit__(exc_type, exc_value, exc_traceback) |
| 71 | + elapsed = time.monotonic() - self._activation_timestamp |
| 72 | + self._remaining_timeout = self._remaining_timeout - timedelta(seconds=elapsed) |
| 73 | + |
| 74 | + self._active_timeout = None |
| 75 | + self._activation_timestamp = None |
| 76 | + |
| 77 | + |
38 | 78 | def format_duration(duration: timedelta | None) -> str: |
39 | 79 | """Format a timedelta into a human-readable string with appropriate units.""" |
40 | 80 | if duration is None: |
|
0 commit comments