Skip to content

Commit c6f09e5

Browse files
committed
Ensured the explicit closing of async generators
# Conflicts: # CHANGELOG.md
1 parent 3fa326b commit c6f09e5

File tree

10 files changed

+80
-58
lines changed

10 files changed

+80
-58
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
66

77
## [UNRELEASED]
88

9+
### Fixed
10+
11+
* Explicitly close all async generators to ensure predictable behavior
12+
913
### Removed
1014

1115
* Drop support for Python 3.8

httpx/_client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
import typing
88
import warnings
9+
from collections.abc import AsyncGenerator
910
from contextlib import asynccontextmanager, contextmanager
1011
from types import TracebackType
1112

@@ -46,7 +47,7 @@
4647
TimeoutTypes,
4748
)
4849
from ._urls import URL, QueryParams
49-
from ._utils import URLPattern, get_environment_proxies
50+
from ._utils import URLPattern, get_environment_proxies, safe_async_iterate
5051

5152
if typing.TYPE_CHECKING:
5253
import ssl # pragma: no cover
@@ -172,9 +173,10 @@ def __init__(
172173
self._response = response
173174
self._start = start
174175

175-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
176-
async for chunk in self._stream:
177-
yield chunk
176+
async def __aiter__(self) -> AsyncGenerator[bytes]:
177+
async with safe_async_iterate(self._stream) as iterator:
178+
async for chunk in iterator:
179+
yield chunk
178180

179181
async def aclose(self) -> None:
180182
elapsed = time.perf_counter() - self._start

httpx/_content.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import inspect
44
import warnings
5+
from collections.abc import AsyncGenerator
56
from json import dumps as json_dumps
67
from typing import (
78
Any,
@@ -10,6 +11,7 @@
1011
Iterable,
1112
Iterator,
1213
Mapping,
14+
NoReturn,
1315
)
1416
from urllib.parse import urlencode
1517

@@ -23,7 +25,7 @@
2325
ResponseContent,
2426
SyncByteStream,
2527
)
26-
from ._utils import peek_filelike_length, primitive_value_to_str
28+
from ._utils import peek_filelike_length, primitive_value_to_str, safe_async_iterate
2729

2830
__all__ = ["ByteStream"]
2931

@@ -35,7 +37,7 @@ def __init__(self, stream: bytes) -> None:
3537
def __iter__(self) -> Iterator[bytes]:
3638
yield self._stream
3739

38-
async def __aiter__(self) -> AsyncIterator[bytes]:
40+
async def __aiter__(self) -> AsyncGenerator[bytes]:
3941
yield self._stream
4042

4143

@@ -85,8 +87,9 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
8587
chunk = await self._stream.aread(self.CHUNK_SIZE)
8688
else:
8789
# Otherwise iterate.
88-
async for part in self._stream:
89-
yield part
90+
async with safe_async_iterate(self._stream) as iterator:
91+
async for part in iterator:
92+
yield part
9093

9194

9295
class UnattachedStream(AsyncByteStream, SyncByteStream):
@@ -99,9 +102,8 @@ class UnattachedStream(AsyncByteStream, SyncByteStream):
99102
def __iter__(self) -> Iterator[bytes]:
100103
raise StreamClosed()
101104

102-
async def __aiter__(self) -> AsyncIterator[bytes]:
105+
def __aiter__(self) -> NoReturn:
103106
raise StreamClosed()
104-
yield b"" # pragma: no cover
105107

106108

107109
def encode_content(

httpx/_models.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import re
88
import typing
99
import urllib.request
10-
from collections.abc import Mapping
10+
from collections.abc import AsyncGenerator, Mapping
1111
from http.cookiejar import Cookie, CookieJar
1212

1313
from ._content import ByteStream, UnattachedStream, encode_request, encode_response
@@ -46,7 +46,7 @@
4646
SyncByteStream,
4747
)
4848
from ._urls import URL
49-
from ._utils import to_bytes_or_str, to_str
49+
from ._utils import safe_async_iterate, to_bytes_or_str, to_str
5050

5151
__all__ = ["Cookies", "Headers", "Request", "Response"]
5252

@@ -979,9 +979,7 @@ async def aread(self) -> bytes:
979979
self._content = b"".join([part async for part in self.aiter_bytes()])
980980
return self._content
981981

982-
async def aiter_bytes(
983-
self, chunk_size: int | None = None
984-
) -> typing.AsyncIterator[bytes]:
982+
async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]:
985983
"""
986984
A byte-iterator over the decoded response content.
987985
This allows us to handle gzip, deflate, brotli, and zstd encoded responses.
@@ -994,19 +992,19 @@ async def aiter_bytes(
994992
decoder = self._get_content_decoder()
995993
chunker = ByteChunker(chunk_size=chunk_size)
996994
with request_context(request=self._request):
997-
async for raw_bytes in self.aiter_raw():
998-
decoded = decoder.decode(raw_bytes)
999-
for chunk in chunker.decode(decoded):
1000-
yield chunk
995+
async with safe_async_iterate(self.aiter_raw()) as iterator:
996+
async for raw_bytes in iterator:
997+
decoded = decoder.decode(raw_bytes)
998+
for chunk in chunker.decode(decoded):
999+
yield chunk
1000+
10011001
decoded = decoder.flush()
10021002
for chunk in chunker.decode(decoded):
10031003
yield chunk # pragma: no cover
10041004
for chunk in chunker.flush():
10051005
yield chunk
10061006

1007-
async def aiter_text(
1008-
self, chunk_size: int | None = None
1009-
) -> typing.AsyncIterator[str]:
1007+
async def aiter_text(self, chunk_size: int | None = None) -> AsyncGenerator[str]:
10101008
"""
10111009
A str-iterator over the decoded response content
10121010
that handles both gzip, deflate, etc but also detects the content's
@@ -1015,28 +1013,28 @@ async def aiter_text(
10151013
decoder = TextDecoder(encoding=self.encoding or "utf-8")
10161014
chunker = TextChunker(chunk_size=chunk_size)
10171015
with request_context(request=self._request):
1018-
async for byte_content in self.aiter_bytes():
1019-
text_content = decoder.decode(byte_content)
1020-
for chunk in chunker.decode(text_content):
1021-
yield chunk
1016+
async with safe_async_iterate(self.aiter_bytes()) as iterator:
1017+
async for byte_content in iterator:
1018+
text_content = decoder.decode(byte_content)
1019+
for chunk in chunker.decode(text_content):
1020+
yield chunk
10221021
text_content = decoder.flush()
10231022
for chunk in chunker.decode(text_content):
10241023
yield chunk # pragma: no cover
10251024
for chunk in chunker.flush():
10261025
yield chunk
10271026

1028-
async def aiter_lines(self) -> typing.AsyncIterator[str]:
1027+
async def aiter_lines(self) -> AsyncGenerator[str]:
10291028
decoder = LineDecoder()
10301029
with request_context(request=self._request):
1031-
async for text in self.aiter_text():
1032-
for line in decoder.decode(text):
1033-
yield line
1030+
async with safe_async_iterate(self.aiter_text()) as iterator:
1031+
async for text in iterator:
1032+
for line in decoder.decode(text):
1033+
yield line
10341034
for line in decoder.flush():
10351035
yield line
10361036

1037-
async def aiter_raw(
1038-
self, chunk_size: int | None = None
1039-
) -> typing.AsyncIterator[bytes]:
1037+
async def aiter_raw(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]:
10401038
"""
10411039
A byte-iterator over the raw response content.
10421040
"""
@@ -1052,10 +1050,11 @@ async def aiter_raw(
10521050
chunker = ByteChunker(chunk_size=chunk_size)
10531051

10541052
with request_context(request=self._request):
1055-
async for raw_stream_bytes in self.stream:
1056-
self._num_bytes_downloaded += len(raw_stream_bytes)
1057-
for chunk in chunker.decode(raw_stream_bytes):
1058-
yield chunk
1053+
async with safe_async_iterate(self.stream) as iterator:
1054+
async for raw_stream_bytes in iterator:
1055+
self._num_bytes_downloaded += len(raw_stream_bytes)
1056+
for chunk in chunker.decode(raw_stream_bytes):
1057+
yield chunk
10591058

10601059
for chunk in chunker.flush():
10611060
yield chunk

httpx/_multipart.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import re
77
import typing
8+
from collections.abc import AsyncGenerator
89
from pathlib import Path
910

1011
from ._types import (
@@ -295,6 +296,6 @@ def __iter__(self) -> typing.Iterator[bytes]:
295296
for chunk in self.iter_chunks():
296297
yield chunk
297298

298-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
299+
async def __aiter__(self) -> AsyncGenerator[bytes]:
299300
for chunk in self.iter_chunks():
300301
yield chunk

httpx/_transports/asgi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import typing
4+
from collections.abc import AsyncGenerator
45

56
from .._models import Request, Response
67
from .._types import AsyncByteStream
@@ -56,7 +57,7 @@ class ASGIResponseStream(AsyncByteStream):
5657
def __init__(self, body: list[bytes]) -> None:
5758
self._body = body
5859

59-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
60+
async def __aiter__(self) -> AsyncGenerator[bytes]:
6061
yield b"".join(self._body)
6162

6263

httpx/_transports/default.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import contextlib
3030
import typing
31+
from collections.abc import AsyncGenerator
3132
from types import TracebackType
3233

3334
if typing.TYPE_CHECKING:
@@ -55,6 +56,7 @@
5556
from .._models import Request, Response
5657
from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream
5758
from .._urls import URL
59+
from .._utils import safe_async_iterate
5860
from .base import AsyncBaseTransport, BaseTransport
5961

6062
T = typing.TypeVar("T", bound="HTTPTransport")
@@ -266,10 +268,11 @@ class AsyncResponseStream(AsyncByteStream):
266268
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None:
267269
self._httpcore_stream = httpcore_stream
268270

269-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
271+
async def __aiter__(self) -> AsyncGenerator[bytes]:
270272
with map_httpcore_exceptions():
271-
async for part in self._httpcore_stream:
272-
yield part
273+
async with safe_async_iterate(self._httpcore_stream) as iterator:
274+
async for part in iterator:
275+
yield part
273276

274277
async def aclose(self) -> None:
275278
if hasattr(self._httpcore_stream, "aclose"):

httpx/_types.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def __iter__(self) -> Iterator[bytes]:
9494
raise NotImplementedError(
9595
"The '__iter__' method must be implemented."
9696
) # pragma: no cover
97-
yield b"" # pragma: no cover
9897

9998
def close(self) -> None:
10099
"""
@@ -104,11 +103,10 @@ def close(self) -> None:
104103

105104

106105
class AsyncByteStream:
107-
async def __aiter__(self) -> AsyncIterator[bytes]:
106+
def __aiter__(self) -> AsyncIterator[bytes]:
108107
raise NotImplementedError(
109108
"The '__aiter__' method must be implemented."
110109
) # pragma: no cover
111-
yield b"" # pragma: no cover
112110

113111
async def aclose(self) -> None:
114112
pass

httpx/_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,18 @@
44
import os
55
import re
66
import typing
7+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
8+
from contextlib import asynccontextmanager
9+
from inspect import isasyncgen
710
from urllib.request import getproxies
811

912
from ._types import PrimitiveData
1013

1114
if typing.TYPE_CHECKING: # pragma: no cover
1215
from ._urls import URL
1316

17+
T = typing.TypeVar("T")
18+
1419

1520
def primitive_value_to_str(value: PrimitiveData) -> str:
1621
"""
@@ -240,3 +245,19 @@ def is_ipv6_hostname(hostname: str) -> bool:
240245
except Exception:
241246
return False
242247
return True
248+
249+
250+
@asynccontextmanager
251+
async def safe_async_iterate(
252+
iterable_or_iterator: AsyncIterable[T] | AsyncIterator[T], /
253+
) -> AsyncGenerator[AsyncIterator[T]]:
254+
iterator = (
255+
iterable_or_iterator
256+
if isinstance(iterable_or_iterator, AsyncIterator)
257+
else iterable_or_iterator.__aiter__()
258+
)
259+
try:
260+
yield iterator
261+
finally:
262+
if isasyncgen(iterator):
263+
await iterator.aclose()

tests/test_content.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
import typing
3+
from collections.abc import AsyncGenerator
34

45
import pytest
56

@@ -64,20 +65,10 @@ async def test_bytesio_content():
6465

6566
@pytest.mark.anyio
6667
async def test_async_bytesio_content():
67-
class AsyncBytesIO:
68-
def __init__(self, content: bytes) -> None:
69-
self._idx = 0
70-
self._content = content
68+
async def fixed_stream(content: bytes) -> AsyncGenerator[bytes]:
69+
yield content
7170

72-
async def aread(self, chunk_size: int) -> bytes:
73-
chunk = self._content[self._idx : self._idx + chunk_size]
74-
self._idx = self._idx + chunk_size
75-
return chunk
76-
77-
async def __aiter__(self):
78-
yield self._content # pragma: no cover
79-
80-
request = httpx.Request(method, url, content=AsyncBytesIO(b"Hello, world!"))
71+
request = httpx.Request(method, url, content=fixed_stream(b"Hello, world!"))
8172
assert not isinstance(request.stream, typing.Iterable)
8273
assert isinstance(request.stream, typing.AsyncIterable)
8374

0 commit comments

Comments
 (0)