@@ -49,6 +49,19 @@ def dprint(*args):
4949def has_array_interface (x ):
5050 return hasattr (x , array_interface_property )
5151
52+ def _get_usm_base (ary ):
53+ ob = ary
54+ while True :
55+ if ob is None :
56+ return None
57+ elif hasattr (ob , '__sycl_usm_array_interface__' ):
58+ return ob
59+ elif isinstance (ob , np .ndarray ):
60+ ob = ob .base
61+ elif isinstance (ob , memoryview ):
62+ ob = ob .obj
63+ else :
64+ return None
5265
5366class ndarray (np .ndarray ):
5467 """
@@ -80,7 +93,9 @@ def __new__(
8093 dprint ("buffer None new_obj already has sycl_usm" )
8194 else :
8295 dprint ("buffer None new_obj will add sycl_usm" )
83- setattr (new_obj , array_interface_property , {})
96+ setattr (new_obj ,
97+ array_interface_property ,
98+ new_obj ._getter_sycl_usm_array_interface_ ())
8499 return new_obj
85100 # zero copy if buffer is a usm backed array-like thing
86101 elif hasattr (buffer , array_interface_property ):
@@ -99,7 +114,8 @@ def __new__(
99114 dprint ("buffer None new_obj already has sycl_usm" )
100115 else :
101116 dprint ("buffer None new_obj will add sycl_usm" )
102- setattr (new_obj , array_interface_property , {})
117+ setattr (new_obj , array_interface_property ,
118+ new_obj ._getter_sycl_usm_array_interface_ ())
103119 return new_obj
104120 else :
105121 dprint ("dparray::ndarray __new__ buffer not None and not sycl_usm" )
@@ -129,9 +145,29 @@ def __new__(
129145 dprint ("buffer None new_obj already has sycl_usm" )
130146 else :
131147 dprint ("buffer None new_obj will add sycl_usm" )
132- setattr (new_obj , array_interface_property , {})
148+ setattr (new_obj , array_interface_property ,
149+ new_obj ._getter_sycl_usm_array_interface_ ())
133150 return new_obj
134151
152+
153+ def _getter_sycl_usm_array_interface_ (self ):
154+ ary_iface = self .__array_interface__
155+ _base = _get_usm_base (self )
156+ if _base is None :
157+ raise TypeError
158+
159+ usm_iface = getattr (_base , '__sycl_usm_array_interface__' , None )
160+ if usm_iface is None :
161+ raise TypeError
162+
163+ if (ary_iface ['data' ][0 ] == usm_iface ['data' ][0 ]):
164+ ary_iface ['version' ] = usm_iface ['version' ]
165+ ary_iface ['syclobj' ] = usm_iface ['syclobj' ]
166+ else :
167+ raise TypeError
168+ return ary_iface
169+
170+
135171 def __array_finalize__ (self , obj ):
136172 dprint ("__array_finalize__:" , obj , hex (id (obj )), type (obj ))
137173 # When called from the explicit constructor, obj is None
@@ -156,6 +192,7 @@ def __array_finalize__(self, obj):
156192 "Non-USM allocated ndarray can not viewed as a USM-allocated one without a copy"
157193 )
158194
195+
159196 # Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
160197 # This way it will use the custom dparray allocator.
161198 __numba_no_subtype_ndarray__ = True
0 commit comments