Skip to content

Commit da5cc26

Browse files
Rogdhammbeijen
andcommitted
Zstandard: fix crash on frame boundaries
Co-authored-by: Michiel W. Beijen <mb@x14.nl>
1 parent 4d05db3 commit da5cc26

File tree

2 files changed

+30
-11
lines changed

2 files changed

+30
-11
lines changed

httpx/_decoders.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -179,28 +179,27 @@ def __init__(self) -> None:
179179
) from None
180180

181181
self.decompressor = zstd.ZstdDecompressor()
182-
self.seen_data = False
182+
self.at_valid_eof = True
183183

184184
def decode(self, data: bytes) -> bytes:
185185
assert zstd is not None
186-
self.seen_data = True
187186
output = io.BytesIO()
188187
try:
189-
output.write(self.decompressor.decompress(data))
190-
while self.decompressor.eof and self.decompressor.unused_data:
191-
unused_data = self.decompressor.unused_data
192-
self.decompressor = zstd.ZstdDecompressor()
193-
output.write(self.decompressor.decompress(unused_data))
188+
while data:
189+
self.at_valid_eof = False
190+
output.write(self.decompressor.decompress(data))
191+
data = self.decompressor.unused_data
192+
if self.decompressor.eof:
193+
self.at_valid_eof = True
194+
self.decompressor = zstd.ZstdDecompressor()
194195
except zstd.ZstdError as exc:
195196
raise DecodingError(str(exc)) from exc
196197
return output.getvalue()
197198

198199
def flush(self) -> bytes:
199-
if not self.seen_data:
200+
if self.at_valid_eof:
200201
return b""
201-
if not self.decompressor.eof:
202-
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
203-
return b""
202+
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
204203

205204

206205
class MultiDecoder(ContentDecoder):

tests/test_decoders.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,26 @@ def test_zstd_multiframe():
146146
assert response.content == b"foobar"
147147

148148

149+
def test_zstd_streaming_multiple_frames():
150+
body1 = b"test 123 "
151+
body2 = b"another frame"
152+
153+
# Create two separate complete frames
154+
frame1 = zstd.compress(body1)
155+
frame2 = zstd.compress(body2)
156+
157+
# Create an iterator that yields frames separately
158+
def content_iterator() -> typing.Iterator[bytes]:
159+
yield frame1
160+
yield frame2
161+
162+
headers = [(b"Content-Encoding", b"zstd")]
163+
response = httpx.Response(200, headers=headers, content=content_iterator())
164+
response.read()
165+
166+
assert response.content == body1 + body2
167+
168+
149169
def test_multi():
150170
body = b"test 123"
151171

0 commit comments

Comments
 (0)