Skip to content

Commit 56dd2dc

Browse files
committed
Implement minimal device_id support
1 parent aaa90db commit 56dd2dc

File tree

6 files changed

+128
-38
lines changed

6 files changed

+128
-38
lines changed

av/_hwdevice_registry.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
# Registry that remembers which CUDA device id a hardware device context was
# created for.
#
# Keys are the integer value of a hwdevice context data pointer; values are the
# CUDA device id passed at registration time. Unknown or NULL pointers resolve
# to device 0.
_cuda_hwdevice_data_ptr_to_device_id: dict[int, int] = {}


def register_cuda_hwdevice_data_ptr(hwdevice_data_ptr: int, device_id: int) -> None:
    """Record that the device context at *hwdevice_data_ptr* uses *device_id*.

    A falsy (NULL / 0) pointer is ignored. Registering the same pointer again
    overwrites the previous device id.

    :param hwdevice_data_ptr: Pointer value of the device context, as an integer.
    :param device_id: CUDA device ordinal to associate with that pointer.
    """
    if hwdevice_data_ptr:
        _cuda_hwdevice_data_ptr_to_device_id[int(hwdevice_data_ptr)] = int(device_id)


def unregister_cuda_hwdevice_data_ptr(hwdevice_data_ptr: int) -> None:
    """Forget the device id recorded for *hwdevice_data_ptr*, if any.

    Intended to be called when a device context is released, so the registry
    does not keep growing and a later context at the same address cannot pick
    up a stale entry. Unknown and falsy pointers are silently ignored.

    :param hwdevice_data_ptr: Pointer value of the device context, as an integer.
    """
    if hwdevice_data_ptr:
        _cuda_hwdevice_data_ptr_to_device_id.pop(int(hwdevice_data_ptr), None)


def lookup_cuda_device_id(hwdevice_data_ptr: int) -> int:
    """Return the CUDA device id registered for *hwdevice_data_ptr*.

    :param hwdevice_data_ptr: Pointer value of the device context, as an integer.
    :return: The registered device id, or ``0`` when the pointer is falsy or
        has never been registered.
    """
    if not hwdevice_data_ptr:
        return 0
    return _cuda_hwdevice_data_ptr_to_device_id.get(int(hwdevice_data_ptr), 0)

av/codec/hwaccel.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from cython.cimports.av.error import err_check
99
from cython.cimports.av.video.format import get_video_format
1010

11+
import av._hwdevice_registry as _hwreg
12+
1113

1214
class HWDeviceType(IntEnum):
1315
none = lib.AV_HWDEVICE_TYPE_NONE
@@ -112,6 +114,9 @@ def __init__(
112114
flags=None,
113115
output_format="sw",
114116
):
117+
if isinstance(device, int):
118+
device = str(device)
119+
115120
if isinstance(device_type, HWDeviceType):
116121
self._device_type = device_type
117122
elif isinstance(device_type, str):
@@ -131,7 +136,10 @@ def __init__(
131136

132137
self._device = device
133138
self.allow_software_fallback = allow_software_fallback
139+
134140
self.options = {} if not options else dict(options)
141+
if self._device_type == HWDeviceType.cuda and self.output_format == "hw":
142+
self.options.setdefault("primary_ctx", "1")
135143
self.flags = 0 if not flags else flags
136144
self.ptr = cython.NULL
137145
self.config = None
@@ -164,6 +172,19 @@ def _initialize_hw_context(self, codec: Codec):
164172
)
165173
)
166174

175+
if config.ptr.device_type == lib.AV_HWDEVICE_TYPE_CUDA:
176+
device_id = 0
177+
if self._device:
178+
try:
179+
device_id = int(self._device)
180+
except ValueError:
181+
device_id = 0
182+
183+
_hwreg.register_cuda_hwdevice_data_ptr(
184+
cython.cast(cython.size_t, self.ptr.data),
185+
device_id,
186+
)
187+
167188
def create(self, codec: Codec):
168189
"""Create a new hardware accelerator context with the given codec"""
169190
if self.ptr:

av/video/frame.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import cython
55
import cython.cimports.libav as lib
66
from cython.cimports.av.dictionary import Dictionary
7-
from cython.cimports.av.dlpack import DLManagedTensor, kDLCUDA, kDLUInt
7+
from cython.cimports.av.dlpack import DLManagedTensor, kDLCUDA, kDLUInt, kDLCPU
88
from cython.cimports.av.error import err_check
99
from cython.cimports.av.hwcontext import (
1010
AVHWFramesContext,
@@ -23,6 +23,7 @@
2323
)
2424
from cython.cimports.libc.stdint import int64_t, uint8_t
2525

26+
import av._hwdevice_registry as _hwreg
2627

2728
_cuda_device_ctx_cache = {}
2829
_cuda_frames_ctx_cache = {}
@@ -67,8 +68,12 @@ def _dlpack_avbuffer_free(
6768
managed.deleter(managed)
6869

6970
@cython.cfunc
70-
def _get_cuda_device_ctx(device_id: cython.int) -> cython.pointer[lib.AVBufferRef]:
71-
cached = _cuda_device_ctx_cache.get(device_id)
71+
def _get_cuda_device_ctx(
72+
device_id: cython.int,
73+
primary_ctx: cython.bint,
74+
) -> cython.pointer[lib.AVBufferRef]:
75+
key = (int(device_id), int(primary_ctx))
76+
cached = _cuda_device_ctx_cache.get(key)
7277
if cached is not None:
7378
return cython.cast(
7479
cython.pointer[lib.AVBufferRef],
@@ -78,7 +83,7 @@ def _get_cuda_device_ctx(device_id: cython.int) -> cython.pointer[lib.AVBufferRe
7883
device_ref: cython.pointer[lib.AVBufferRef] = cython.NULL
7984
device_bytes = str(device_id).encode()
8085
c_device: cython.p_char = device_bytes
81-
options: Dictionary = Dictionary({"primary_ctx": "1"})
86+
options: Dictionary = Dictionary({"primary_ctx": "1" if primary_ctx else "0"})
8287

8388
err_check(
8489
lib.av_hwdevice_ctx_create(
@@ -90,25 +95,31 @@ def _get_cuda_device_ctx(device_id: cython.int) -> cython.pointer[lib.AVBufferRe
9095
)
9196
)
9297

93-
_cuda_device_ctx_cache[device_id] = cython.cast(cython.size_t, device_ref)
98+
_hwreg.register_cuda_hwdevice_data_ptr(
99+
cython.cast(cython.size_t, device_ref.data),
100+
device_id,
101+
)
102+
103+
_cuda_device_ctx_cache[key] = cython.cast(cython.size_t, device_ref)
94104
return device_ref
95105

96106
@cython.cfunc
97107
def _get_cuda_frames_ctx(
98108
device_id: cython.int,
109+
primary_ctx: cython.bint,
99110
sw_fmt: lib.AVPixelFormat,
100111
width: cython.int,
101112
height: cython.int,
102113
) -> cython.pointer[lib.AVBufferRef]:
103-
key = (device_id, int(sw_fmt), int(width), int(height))
114+
key = (int(device_id), int(primary_ctx), int(sw_fmt), int(width), int(height))
104115
cached = _cuda_frames_ctx_cache.get(key)
105116
if cached is not None:
106117
return cython.cast(
107118
cython.pointer[lib.AVBufferRef],
108119
cython.cast(cython.size_t, cached),
109120
)
110121

111-
device_ref = _get_cuda_device_ctx(device_id)
122+
device_ref = _get_cuda_device_ctx(device_id, primary_ctx)
112123
frames_ref = av_hwframe_ctx_alloc(device_ref)
113124
if frames_ref == cython.NULL:
114125
raise MemoryError("av_hwframe_ctx_alloc() failed")
@@ -1330,6 +1341,7 @@ def from_dlpack(
13301341
height: int = 0,
13311342
stream=None,
13321343
device_id: int | None = None,
1344+
primary_ctx: bool = True,
13331345
):
13341346
if not isinstance(planes, (tuple, list)):
13351347
planes = (planes,)
@@ -1356,18 +1368,30 @@ def from_dlpack(
13561368
m0 = _consume_dlpack(planes[0], stream)
13571369
m1 = _consume_dlpack(planes[1], stream)
13581370

1359-
if m0.dl_tensor.device.device_type != kDLCUDA or m1.dl_tensor.device.device_type != kDLCUDA:
1360-
raise TypeError("only CUDA DLPack tensors are supported")
1371+
dev_type0 = m0.dl_tensor.device.device_type
1372+
dev_type1 = m1.dl_tensor.device.device_type
1373+
if dev_type0 != dev_type1:
1374+
raise ValueError("plane tensors must have the same device_type")
1375+
if dev_type0 not in {kDLCUDA, kDLCPU}:
1376+
raise NotImplementedError("only CPU and CUDA DLPack tensors are supported")
13611377

13621378
dev0 = m0.dl_tensor.device.device_id
13631379
dev1 = m1.dl_tensor.device.device_id
13641380
if dev0 != dev1:
13651381
raise ValueError("plane tensors must be on the same CUDA device")
1366-
1367-
if device_id is None:
1368-
device_id = dev0
1369-
elif device_id != dev0:
1370-
raise ValueError("device_id does not match the DLPack tensor device_id")
1382+
if dev_type0 == kDLCUDA:
1383+
if dev0 != dev1:
1384+
raise ValueError("plane tensors must be on the same CUDA device")
1385+
if device_id is None:
1386+
device_id = dev0
1387+
elif device_id != dev0:
1388+
raise ValueError("device_id does not match the DLPack tensor device_id")
1389+
else:
1390+
if device_id not in (None, 0):
1391+
raise ValueError("device_id must be 0 for CPU tensors")
1392+
device_id = 0
1393+
if dev_type0 == kDLCPU and (dev0 != 0 or dev1 != 0):
1394+
raise ValueError("CPU DLPack tensors must have device_id == 0")
13711395

13721396
if (
13731397
m0.dl_tensor.dtype.code != kDLUInt
@@ -1443,16 +1467,24 @@ def from_dlpack(
14431467
uv_linesize = cython.cast(int, uv_pitch_elems * itemsize)
14441468
uv_size = cython.cast(int, uv_linesize * (height // 2))
14451469

1446-
frames_ref = _get_cuda_frames_ctx(device_id, sw_fmt, width, height)
1447-
14481470
frame = alloc_video_frame()
14491471
frame.ptr.width = width
14501472
frame.ptr.height = height
1451-
frame.ptr.format = get_pix_fmt(b"cuda")
1452-
1453-
frame.ptr.hw_frames_ctx = lib.av_buffer_ref(frames_ref)
1454-
if frame.ptr.hw_frames_ctx == cython.NULL:
1455-
raise MemoryError("av_buffer_ref(hw_frames_ctx) failed")
1473+
if dev_type0 == kDLCUDA:
1474+
if primary_ctx is None:
1475+
primary_ctx = True
1476+
if not isinstance(primary_ctx, (bool, int)):
1477+
raise TypeError("primary_ctx must be a bool")
1478+
primary_ctx = bool(primary_ctx)
1479+
1480+
frames_ref = _get_cuda_frames_ctx(device_id, primary_ctx, sw_fmt, width, height)
1481+
1482+
frame.ptr.format = get_pix_fmt(b"cuda")
1483+
frame.ptr.hw_frames_ctx = lib.av_buffer_ref(frames_ref)
1484+
if frame.ptr.hw_frames_ctx == cython.NULL:
1485+
raise MemoryError("av_buffer_ref(hw_frames_ctx) failed")
1486+
else:
1487+
frame.ptr.format = sw_fmt
14561488

14571489
y_ptr = cython.cast(cython.pointer[uint8_t], m0.dl_tensor.data) + cython.cast(
14581490
cython.size_t, m0.dl_tensor.byte_offset

av/video/frame.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,5 @@ class VideoFrame(Frame):
9292
height: int = 0,
9393
stream: int | None = None,
9494
device_id: int | None = None,
95+
primary_ctx: bool = True,
9596
) -> "VideoFrame": ...

av/video/plane.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import cython
22
import cython.cimports.libav as lib
33
from cython.cimports.av.buffer import Buffer
4-
from cython.cimports.av.dlpack import DLManagedTensor, kDLCUDA, kDLUInt
4+
from cython.cimports.av.dlpack import DLManagedTensor, kDLCPU, kDLCUDA, kDLUInt
55
from cython.cimports.av.error import err_check
66
from cython.cimports.av.hwcontext import AVHWFramesContext
77
from cython.cimports.av.video.format import get_pix_fmt, get_video_format
@@ -17,6 +17,8 @@
1717
from cython.cimports.libc.stdint import int64_t
1818
from cython.cimports.libc.stdlib import free, malloc
1919

20+
import av._hwdevice_registry as _hwreg
21+
2022

2123
@cython.cclass
2224
class VideoPlane(Plane):
@@ -79,22 +81,44 @@ def __getbuffer__(self, view: cython.pointer[Py_buffer], flags: cython.int):
7981
PyBuffer_FillInfo(view, self, self._buffer_ptr(), self._buffer_size(), 0, flags)
8082

8183
def __dlpack_device__(self):
82-
if not self.frame.ptr.hw_frames_ctx:
83-
raise TypeError("DLPack export is only supported for hardware frames")
84-
if cython.cast(lib.AVPixelFormat, self.frame.ptr.format) != get_pix_fmt(b"cuda"):
85-
raise NotImplementedError("DLPack export is only implemented for CUDA hw frames")
86-
return (kDLCUDA, 0)
84+
if self.frame.ptr.hw_frames_ctx:
85+
if cython.cast(lib.AVPixelFormat, self.frame.ptr.format) != get_pix_fmt(b"cuda"):
86+
raise NotImplementedError("DLPack export is only implemented for CUDA hw frames")
87+
88+
frames_ctx: cython.pointer[AVHWFramesContext] = cython.cast(
89+
cython.pointer[AVHWFramesContext], self.frame.ptr.hw_frames_ctx.data
90+
)
91+
device_id = _hwreg.lookup_cuda_device_id(
92+
cython.cast(cython.size_t, frames_ctx.device_ref.data)
93+
)
94+
return (kDLCUDA, device_id)
95+
96+
return (kDLCPU, 0)
8797

8898
def __dlpack__(self, stream=None):
89-
if not self.frame.ptr.hw_frames_ctx:
90-
raise TypeError("DLPack export is only supported for hardware frames")
91-
if cython.cast(lib.AVPixelFormat, self.frame.ptr.format) != get_pix_fmt(b"cuda"):
92-
raise NotImplementedError("DLPack export is only implemented for CUDA hw frames")
99+
if self.frame.ptr.buf[0] == cython.NULL:
100+
raise TypeError("DLPack export requires a refcounted AVFrame (frame.buf[0] is NULL)")
93101

94-
frames_ctx: cython.pointer[AVHWFramesContext] = cython.cast(
95-
cython.pointer[AVHWFramesContext], self.frame.ptr.hw_frames_ctx.data
96-
)
97-
sw_fmt = frames_ctx.sw_format
102+
device_type: cython.int
103+
device_id: cython.int
104+
sw_fmt: lib.AVPixelFormat
105+
106+
if self.frame.ptr.hw_frames_ctx:
107+
if cython.cast(lib.AVPixelFormat, self.frame.ptr.format) != get_pix_fmt(b"cuda"):
108+
raise NotImplementedError("DLPack export is only implemented for CUDA hw frames")
109+
110+
frames_ctx: cython.pointer[AVHWFramesContext] = cython.cast(
111+
cython.pointer[AVHWFramesContext], self.frame.ptr.hw_frames_ctx.data
112+
)
113+
sw_fmt = frames_ctx.sw_format
114+
device_type = kDLCUDA
115+
device_id = _hwreg.lookup_cuda_device_id(
116+
cython.cast(cython.size_t, frames_ctx.device_ref.data)
117+
)
118+
else:
119+
sw_fmt = cython.cast(lib.AVPixelFormat, self.frame.ptr.format)
120+
device_type = kDLCPU
121+
device_id = 0
98122

99123
line_size = self.line_size
100124
if line_size < 0:
@@ -206,8 +230,8 @@ def __dlpack__(self, stream=None):
206230
raise MemoryError("malloc() failed")
207231

208232
managed.dl_tensor.data = cython.cast(cython.p_void, frame_ref.data[self.index])
209-
managed.dl_tensor.device.device_type = kDLCUDA
210-
managed.dl_tensor.device.device_id = 0
233+
managed.dl_tensor.device.device_type = device_type
234+
managed.dl_tensor.device.device_id = device_id
211235
managed.dl_tensor.ndim = ndim
212236
managed.dl_tensor.dtype.code = kDLUInt
213237
managed.dl_tensor.dtype.bits = bits

include/libavcodec/avcodec.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ cdef extern from "libavcodec/avcodec.h" nogil:
353353
int64_t pkt_dts
354354
void *opaque
355355
int sample_rate
356-
AVBufferRef *buf[4]
356+
AVBufferRef *buf[8]
357357
AVBufferRef **extended_buf
358358
int nb_extended_buf
359359

0 commit comments

Comments
 (0)