@@ -71,6 +71,7 @@ cdef extern from 'dlpack/dlpack.h' nogil:
7171 kDLFloat
7272 kDLBfloat
7373 kDLComplex
74+ kDLBool
7475
7576 ctypedef struct DLDataType:
7677 uint8_t code
@@ -244,7 +245,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
244245 dl_tensor.dtype.lanes = < uint16_t> 1
245246 dl_tensor.dtype.bits = < uint8_t> (ary_dt.itemsize * 8 )
246247 if (ary_dtk == " b" ):
247- dl_tensor.dtype.code = < uint8_t> kDLUInt
248+ dl_tensor.dtype.code = < uint8_t> kDLBool
248249 elif (ary_dtk == " u" ):
249250 dl_tensor.dtype.code = < uint8_t> kDLUInt
250251 elif (ary_dtk == " i" ):
@@ -311,14 +312,17 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
311312 cdef DLManagedTensor * dlm_tensor = NULL
312313 cdef bytes usm_type
313314 cdef size_t sz = 1
315+ cdef size_t alloc_sz = 1
314316 cdef int i
315317 cdef int device_id = - 1
316318 cdef int element_bytesize = 0
317319 cdef Py_ssize_t offset_min = 0
318320 cdef Py_ssize_t offset_max = 0
319- cdef int64_t stride_i
320321 cdef char * mem_ptr = NULL
322+ cdef Py_ssize_t mem_ptr_delta = 0
321323 cdef Py_ssize_t element_offset = 0
324+ cdef int64_t stride_i = - 1
325+ cdef int64_t shape_i = - 1
322326
323327 if not cpython.PyCapsule_IsValid(py_caps, ' dltensor' ):
324328 if cpython.PyCapsule_IsValid(py_caps, ' used_dltensor' ):
@@ -370,22 +374,22 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
370374 raise BufferError(
371375 " Can not import DLPack tensor with lanes != 1"
372376 )
377+ offset_min = 0
373378 if dlm_tensor.dl_tensor.strides is NULL :
374379 for i in range (dlm_tensor.dl_tensor.ndim):
375380 sz = sz * dlm_tensor.dl_tensor.shape[i]
381+ offset_max = sz - 1
376382 else :
377- offset_min = 0
378383 offset_max = 0
379384 for i in range (dlm_tensor.dl_tensor.ndim):
380385 stride_i = dlm_tensor.dl_tensor.strides[i]
381- if stride_i > 0 :
382- offset_max = offset_max + stride_i * (
383- dlm_tensor.dl_tensor.shape[i] - 1
384- )
385- else :
386- offset_min = offset_min + stride_i * (
387- dlm_tensor.dl_tensor.shape[i] - 1
388- )
386+ shape_i = dlm_tensor.dl_tensor.shape[i]
387+ if shape_i > 1 :
388+ shape_i -= 1
389+ if stride_i > 0 :
390+ offset_max = offset_max + stride_i * shape_i
391+ else :
392+ offset_min = offset_min + stride_i * shape_i
389393 sz = offset_max - offset_min + 1
390394 if sz == 0 :
391395 sz = 1
@@ -401,14 +405,29 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
401405 if dlm_tensor.dl_tensor.data is NULL :
402406 usm_mem = dpmem.MemoryUSMDevice(sz, q)
403407 else :
404- mem_ptr = < char * > dlm_tensor.dl_tensor.data + dlm_tensor.dl_tensor.byte_offset
405- mem_ptr = mem_ptr - (element_offset * element_bytesize)
406- usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
408+ mem_ptr_delta = dlm_tensor.dl_tensor.byte_offset - (
409+ element_offset * element_bytesize
410+ )
411+ mem_ptr = < char * > dlm_tensor.dl_tensor.data
412+ alloc_sz = dlm_tensor.dl_tensor.byte_offset + < uint64_t> (
413+ (offset_max + 1 ) * element_bytesize)
414+ tmp = c_dpmem._Memory.create_from_usm_pointer_size_qref(
407415 < DPCTLSyclUSMRef> mem_ptr,
408- sz ,
416+ max (alloc_sz, < uint64_t > element_bytesize) ,
409417 (< c_dpctl.SyclQueue> q).get_queue_ref(),
410418 memory_owner = dlm_holder
411419 )
420+ if mem_ptr_delta == 0 :
421+ usm_mem = tmp
422+ else :
423+ alloc_sz = dlm_tensor.dl_tensor.byte_offset + < uint64_t> (
424+ (offset_max * element_bytesize + mem_ptr_delta))
425+ usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
426+ < DPCTLSyclUSMRef> (mem_ptr + (element_bytesize - mem_ptr_delta)),
427+ max (alloc_sz, < uint64_t> element_bytesize),
428+ (< c_dpctl.SyclQueue> q).get_queue_ref(),
429+ memory_owner = tmp
430+ )
412431 py_shape = list ()
413432 for i in range (dlm_tensor.dl_tensor.ndim):
414433 py_shape.append(dlm_tensor.dl_tensor.shape[i])
@@ -426,8 +445,10 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
426445 ary_dt = np.dtype(" f" + str (element_bytesize))
427446 elif (dlm_tensor.dl_tensor.dtype.code == kDLComplex):
428447 ary_dt = np.dtype(" c" + str (element_bytesize))
448+ elif (dlm_tensor.dl_tensor.dtype.code == kDLBool):
449+ ary_dt = np.dtype(" ?" )
429450 else :
430- raise ValueError (
451+ raise BufferError (
431452 " Can not import DLPack tensor with type code {}." .format(
432453 < object > dlm_tensor.dl_tensor.dtype.code
433454 )
@@ -441,7 +462,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
441462 )
442463 return res_ary
443464 else :
444- raise ValueError (
465+ raise BufferError (
445466 " The DLPack tensor resides on unsupported device."
446467 )
447468
0 commit comments