Skip to content

Commit c346b18

Browse files
committed
Fix test using .ndarray and .reformat
Due to differences in current memory reference methods, frames created using .from_ndarray or .reformat are not supported by __dlpack__.
1 parent 0af7dcf commit c346b18

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

tests/test_dlpack.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -230,30 +230,34 @@ def test_video_frame_from_dlpack_p010_p016_cpu(fmt: str) -> None:
230230

231231

232232
def test_video_plane_dlpack_export_keeps_frame_alive_after_gc() -> None:
233-
container = av.open(fate_png())
234-
frame = next(container.decode(video=0))
235-
frame_nv12 = frame.reformat(format="nv12")
233+
width, height = 64, 48
234+
frame = VideoFrame(width, height, "nv12")
236235

237-
width = frame_nv12.width
238-
height = frame_nv12.height
239-
line_size = frame_nv12.planes[0].line_size
240-
expected = _plane_to_2d(frame_nv12.planes[0], height, width, numpy.uint8).copy()
236+
y_bytes = numpy.arange(frame.planes[0].buffer_size, dtype=numpy.uint8).tobytes()
237+
frame.planes[0].update(y_bytes)
241238

242-
y_dl = numpy.from_dlpack(frame_nv12.planes[0])
243-
assert y_dl.shape == (height, width)
244-
assert y_dl.strides == (line_size, 1)
239+
y_dl = numpy.from_dlpack(frame.planes[0])
240+
expected = y_dl.copy()
245241

246-
del frame_nv12
247242
del frame
248-
del container
249243
gc.collect()
250244

251245
assertNdarraysEqual(y_dl, expected)
252246

253247

254-
def test_video_plane_dlpack_unsupported_format_raises() -> None:
248+
def test_video_plane_dlpack_requires_refcounted_frame() -> None:
249+
# TODO: By extending `from_dlpack` to implement all pixel formats
250+
# of `from_ndarray`, the issue can be resolved by calling
251+
# `from_dlpack(np.array)` internally within `from_ndarray`.
255252
rgb = numpy.zeros((16, 16, 3), dtype=numpy.uint8)
256253
frame = VideoFrame.from_ndarray(rgb, format="rgb24")
254+
255+
with pytest.raises(TypeError, match="refcounted AVFrame"):
256+
frame.planes[0].__dlpack__()
257+
258+
259+
def test_video_plane_dlpack_unsupported_format_raises() -> None:
260+
frame = VideoFrame(16, 16, "rgb24")
257261
assert frame.planes[0].__dlpack_device__() == (1, 0)
258262

259263
with pytest.raises(

0 commit comments

Comments
 (0)