44import cython
55import cython .cimports .libav as lib
66from cython .cimports .av .dictionary import Dictionary
7- from cython .cimports .av .dlpack import DLManagedTensor , kDLCUDA , kDLUInt
7+ from cython .cimports .av .dlpack import DLManagedTensor , kDLCUDA , kDLUInt , kDLCPU
88from cython .cimports .av .error import err_check
99from cython .cimports .av .hwcontext import (
1010 AVHWFramesContext ,
2323)
2424from cython .cimports .libc .stdint import int64_t , uint8_t
2525
26+ import av ._hwdevice_registry as _hwreg
2627
2728_cuda_device_ctx_cache = {}
2829_cuda_frames_ctx_cache = {}
@@ -67,8 +68,12 @@ def _dlpack_avbuffer_free(
6768 managed .deleter (managed )
6869
6970@cython .cfunc
70- def _get_cuda_device_ctx (device_id : cython .int ) -> cython .pointer [lib .AVBufferRef ]:
71- cached = _cuda_device_ctx_cache .get (device_id )
71+ def _get_cuda_device_ctx (
72+ device_id : cython .int ,
73+ primary_ctx : cython .bint ,
74+ ) -> cython .pointer [lib .AVBufferRef ]:
75+ key = (int (device_id ), int (primary_ctx ))
76+ cached = _cuda_device_ctx_cache .get (key )
7277 if cached is not None :
7378 return cython .cast (
7479 cython .pointer [lib .AVBufferRef ],
@@ -78,7 +83,7 @@ def _get_cuda_device_ctx(device_id: cython.int) -> cython.pointer[lib.AVBufferRe
7883 device_ref : cython .pointer [lib .AVBufferRef ] = cython .NULL
7984 device_bytes = str (device_id ).encode ()
8085 c_device : cython .p_char = device_bytes
81- options : Dictionary = Dictionary ({"primary_ctx" : "1" })
86+ options : Dictionary = Dictionary ({"primary_ctx" : "1" if primary_ctx else "0" })
8287
8388 err_check (
8489 lib .av_hwdevice_ctx_create (
@@ -90,25 +95,31 @@ def _get_cuda_device_ctx(device_id: cython.int) -> cython.pointer[lib.AVBufferRe
9095 )
9196 )
9297
93- _cuda_device_ctx_cache [device_id ] = cython .cast (cython .size_t , device_ref )
98+ _hwreg .register_cuda_hwdevice_data_ptr (
99+ cython .cast (cython .size_t , device_ref .data ),
100+ device_id ,
101+ )
102+
103+ _cuda_device_ctx_cache [key ] = cython .cast (cython .size_t , device_ref )
94104 return device_ref
95105
96106@cython .cfunc
97107def _get_cuda_frames_ctx (
98108 device_id : cython .int ,
109+ primary_ctx : cython .bint ,
99110 sw_fmt : lib .AVPixelFormat ,
100111 width : cython .int ,
101112 height : cython .int ,
102113) -> cython .pointer [lib .AVBufferRef ]:
103- key = (device_id , int (sw_fmt ), int (width ), int (height ))
114+ key = (int ( device_id ), int ( primary_ctx ) , int (sw_fmt ), int (width ), int (height ))
104115 cached = _cuda_frames_ctx_cache .get (key )
105116 if cached is not None :
106117 return cython .cast (
107118 cython .pointer [lib .AVBufferRef ],
108119 cython .cast (cython .size_t , cached ),
109120 )
110121
111- device_ref = _get_cuda_device_ctx (device_id )
122+ device_ref = _get_cuda_device_ctx (device_id , primary_ctx )
112123 frames_ref = av_hwframe_ctx_alloc (device_ref )
113124 if frames_ref == cython .NULL :
114125 raise MemoryError ("av_hwframe_ctx_alloc() failed" )
@@ -1330,6 +1341,7 @@ def from_dlpack(
13301341 height : int = 0 ,
13311342 stream = None ,
13321343 device_id : int | None = None ,
1344+ primary_ctx : bool = True ,
13331345 ):
13341346 if not isinstance (planes , (tuple , list )):
13351347 planes = (planes ,)
@@ -1356,18 +1368,30 @@ def from_dlpack(
13561368 m0 = _consume_dlpack (planes [0 ], stream )
13571369 m1 = _consume_dlpack (planes [1 ], stream )
13581370
1359- if m0 .dl_tensor .device .device_type != kDLCUDA or m1 .dl_tensor .device .device_type != kDLCUDA :
1360- raise TypeError ("only CUDA DLPack tensors are supported" )
1371+ dev_type0 = m0 .dl_tensor .device .device_type
1372+ dev_type1 = m1 .dl_tensor .device .device_type
1373+ if dev_type0 != dev_type1 :
1374+ raise ValueError ("plane tensors must have the same device_type" )
1375+ if dev_type0 not in {kDLCUDA , kDLCPU }:
1376+ raise NotImplementedError ("only CPU and CUDA DLPack tensors are supported" )
13611377
13621378 dev0 = m0 .dl_tensor .device .device_id
13631379 dev1 = m1 .dl_tensor .device .device_id
13641380 if dev0 != dev1 :
13651381 raise ValueError ("plane tensors must be on the same CUDA device" )
1366-
1367- if device_id is None :
1368- device_id = dev0
1369- elif device_id != dev0 :
1370- raise ValueError ("device_id does not match the DLPack tensor device_id" )
1382+ if dev_type0 == kDLCUDA :
1383+ if dev0 != dev1 :
1384+ raise ValueError ("plane tensors must be on the same CUDA device" )
1385+ if device_id is None :
1386+ device_id = dev0
1387+ elif device_id != dev0 :
1388+ raise ValueError ("device_id does not match the DLPack tensor device_id" )
1389+ else :
1390+ if device_id not in (None , 0 ):
1391+ raise ValueError ("device_id must be 0 for CPU tensors" )
1392+ device_id = 0
1393+ if dev_type0 == kDLCPU and (dev0 != 0 or dev1 != 0 ):
1394+ raise ValueError ("CPU DLPack tensors must have device_id == 0" )
13711395
13721396 if (
13731397 m0 .dl_tensor .dtype .code != kDLUInt
@@ -1443,16 +1467,24 @@ def from_dlpack(
14431467 uv_linesize = cython .cast (int , uv_pitch_elems * itemsize )
14441468 uv_size = cython .cast (int , uv_linesize * (height // 2 ))
14451469
1446- frames_ref = _get_cuda_frames_ctx (device_id , sw_fmt , width , height )
1447-
14481470 frame = alloc_video_frame ()
14491471 frame .ptr .width = width
14501472 frame .ptr .height = height
1451- frame .ptr .format = get_pix_fmt (b"cuda" )
1452-
1453- frame .ptr .hw_frames_ctx = lib .av_buffer_ref (frames_ref )
1454- if frame .ptr .hw_frames_ctx == cython .NULL :
1455- raise MemoryError ("av_buffer_ref(hw_frames_ctx) failed" )
1473+ if dev_type0 == kDLCUDA :
1474+ if primary_ctx is None :
1475+ primary_ctx = True
1476+ if not isinstance (primary_ctx , (bool , int )):
1477+ raise TypeError ("primary_ctx must be a bool" )
1478+ primary_ctx = bool (primary_ctx )
1479+
1480+ frames_ref = _get_cuda_frames_ctx (device_id , primary_ctx , sw_fmt , width , height )
1481+
1482+ frame .ptr .format = get_pix_fmt (b"cuda" )
1483+ frame .ptr .hw_frames_ctx = lib .av_buffer_ref (frames_ref )
1484+ if frame .ptr .hw_frames_ctx == cython .NULL :
1485+ raise MemoryError ("av_buffer_ref(hw_frames_ctx) failed" )
1486+ else :
1487+ frame .ptr .format = sw_fmt
14561488
14571489 y_ptr = cython .cast (cython .pointer [uint8_t ], m0 .dl_tensor .data ) + cython .cast (
14581490 cython .size_t , m0 .dl_tensor .byte_offset
0 commit comments