File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed
Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -71,6 +71,7 @@ cdef extern from 'dlpack/dlpack.h' nogil:
7171 kDLFloat
7272 kDLBfloat
7373 kDLComplex
74+ kDLBool
7475
7576 ctypedef struct DLDataType:
7677 uint8_t code
@@ -244,7 +245,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
244245 dl_tensor.dtype.lanes = < uint16_t> 1
245246 dl_tensor.dtype.bits = < uint8_t> (ary_dt.itemsize * 8 )
246247 if (ary_dtk == " b" ):
247- dl_tensor.dtype.code = < uint8_t> kDLUInt
248+ dl_tensor.dtype.code = < uint8_t> kDLBool
248249 elif (ary_dtk == " u" ):
249250 dl_tensor.dtype.code = < uint8_t> kDLUInt
250251 elif (ary_dtk == " i" ):
@@ -444,6 +445,8 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
444445 ary_dt = np.dtype(" f" + str (element_bytesize))
445446 elif (dlm_tensor.dl_tensor.dtype.code == kDLComplex):
446447 ary_dt = np.dtype(" c" + str (element_bytesize))
448+ elif (dlm_tensor.dl_tensor.dtype.code == kDLBool):
449+ ary_dt = np.dtype(" ?" )
447450 else :
448451 raise BufferError(
449452 " Can not import DLPack tensor with type code {}." .format(
You can’t perform that action at this time.
0 commit comments