Skip to content

Commit 7eee90d

Browse files
Rogdham and mbeijen committed
Zstandard: fix crash on frame boundaries
Co-authored-by: Michiel W. Beijen <mb@x14.nl>
1 parent 4d05db3 commit 7eee90d

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

httpx/_decoders.py

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -179,28 +179,27 @@ def __init__(self) -> None:
179179
) from None
180180

181181
self.decompressor = zstd.ZstdDecompressor()
182-
self.seen_data = False
182+
self.at_valid_eof = True
183183

184184
def decode(self, data: bytes) -> bytes:
185185
assert zstd is not None
186-
self.seen_data = True
187186
output = io.BytesIO()
188187
try:
189-
output.write(self.decompressor.decompress(data))
190-
while self.decompressor.eof and self.decompressor.unused_data:
191-
unused_data = self.decompressor.unused_data
192-
self.decompressor = zstd.ZstdDecompressor()
193-
output.write(self.decompressor.decompress(unused_data))
188+
self.at_valid_eof = False
189+
while data:
190+
output.write(self.decompressor.decompress(data))
191+
data = self.decompressor.unused_data
192+
if self.decompressor.eof:
193+
self.decompressor = zstd.ZstdDecompressor()
194+
self.at_valid_eof = not data
194195
except zstd.ZstdError as exc:
195196
raise DecodingError(str(exc)) from exc
196197
return output.getvalue()
197198

198199
def flush(self) -> bytes:
199-
if not self.seen_data:
200+
if self.at_valid_eof:
200201
return b""
201-
if not self.decompressor.eof:
202-
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
203-
return b""
202+
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
204203

205204

206205
class MultiDecoder(ContentDecoder):

tests/test_decoders.py

Lines changed: 34 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ def test_zstd_truncated():
120120
httpx.Response(
121121
200,
122122
headers=headers,
123-
content=compressed_body[1:3],
123+
content=compressed_body[:-1],
124124
)
125125

126126

@@ -146,6 +146,39 @@ def test_zstd_multiframe():
146146
assert response.content == b"foobar"
147147

148148

149+
def test_zstd_truncated_multiframe():
150+
body = b"test 123"
151+
compressed_body = zstd.compress(body)
152+
153+
headers = [(b"Content-Encoding", b"zstd")]
154+
with pytest.raises(httpx.DecodingError):
155+
httpx.Response(
156+
200,
157+
headers=headers,
158+
content=compressed_body + compressed_body[:-1],
159+
)
160+
161+
162+
def test_zstd_streaming_multiple_frames():
163+
body1 = b"test 123 "
164+
body2 = b"another frame"
165+
166+
# Create two separate complete frames
167+
frame1 = zstd.compress(body1)
168+
frame2 = zstd.compress(body2)
169+
170+
# Create an iterator that yields frames separately
171+
def content_iterator() -> typing.Iterator[bytes]:
172+
yield frame1
173+
yield frame2
174+
175+
headers = [(b"Content-Encoding", b"zstd")]
176+
response = httpx.Response(200, headers=headers, content=content_iterator())
177+
response.read()
178+
179+
assert response.content == body1 + body2
180+
181+
149182
def test_multi():
150183
body = b"test 123"
151184

0 commit comments

Comments
 (0)