@@ -48,12 +48,13 @@ def dprint(*args):
4848def has_array_interface (x ):
4949 return hasattr (x , array_interface_property )
5050
51+
5152def _get_usm_base (ary ):
5253 ob = ary
5354 while True :
5455 if ob is None :
5556 return None
56- elif hasattr (ob , ' __sycl_usm_array_interface__' ):
57+ elif hasattr (ob , " __sycl_usm_array_interface__" ):
5758 return ob
5859 elif isinstance (ob , np .ndarray ):
5960 ob = ob .base
@@ -92,9 +93,11 @@ def __new__(
9293 dprint ("buffer None new_obj already has sycl_usm" )
9394 else :
9495 dprint ("buffer None new_obj will add sycl_usm" )
95- setattr (new_obj ,
96- array_interface_property ,
97- new_obj ._getter_sycl_usm_array_interface_ ())
96+ setattr (
97+ new_obj ,
98+ array_interface_property ,
99+ new_obj ._getter_sycl_usm_array_interface_ (),
100+ )
98101 return new_obj
99102 # zero copy if buffer is a usm backed array-like thing
100103 elif hasattr (buffer , array_interface_property ):
@@ -113,8 +116,11 @@ def __new__(
113116 dprint ("buffer None new_obj already has sycl_usm" )
114117 else :
115118 dprint ("buffer None new_obj will add sycl_usm" )
116- setattr (new_obj , array_interface_property ,
117- new_obj ._getter_sycl_usm_array_interface_ ())
119+ setattr (
120+ new_obj ,
121+ array_interface_property ,
122+ new_obj ._getter_sycl_usm_array_interface_ (),
123+ )
118124 return new_obj
119125 else :
120126 dprint ("dparray::ndarray __new__ buffer not None and not sycl_usm" )
@@ -144,29 +150,30 @@ def __new__(
144150 dprint ("buffer None new_obj already has sycl_usm" )
145151 else :
146152 dprint ("buffer None new_obj will add sycl_usm" )
147- setattr (new_obj , array_interface_property ,
148- new_obj ._getter_sycl_usm_array_interface_ ())
153+ setattr (
154+ new_obj ,
155+ array_interface_property ,
156+ new_obj ._getter_sycl_usm_array_interface_ (),
157+ )
149158 return new_obj
150159
151-
152160 def _getter_sycl_usm_array_interface_ (self ):
153161 ary_iface = self .__array_interface__
154162 _base = _get_usm_base (self )
155163 if _base is None :
156164 raise TypeError
157165
158- usm_iface = getattr (_base , ' __sycl_usm_array_interface__' , None )
166+ usm_iface = getattr (_base , " __sycl_usm_array_interface__" , None )
159167 if usm_iface is None :
160168 raise TypeError
161169
162- if ( ary_iface [' data' ][0 ] == usm_iface [' data' ][0 ]) :
163- ary_iface [' version' ] = usm_iface [' version' ]
164- ary_iface [' syclobj' ] = usm_iface [' syclobj' ]
170+ if ary_iface [" data" ][0 ] == usm_iface [" data" ][0 ]:
171+ ary_iface [" version" ] = usm_iface [" version" ]
172+ ary_iface [" syclobj" ] = usm_iface [" syclobj" ]
165173 else :
166174 raise TypeError
167175 return ary_iface
168176
169-
170177 def __array_finalize__ (self , obj ):
171178 dprint ("__array_finalize__:" , obj , hex (id (obj )), type (obj ))
172179 # When called from the explicit constructor, obj is None
@@ -191,7 +198,6 @@ def __array_finalize__(self, obj):
191198 "Non-USM allocated ndarray can not viewed as a USM-allocated one without a copy"
192199 )
193200
194-
195201 # Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
196202 # This way it will use the custom dparray allocator.
197203 __numba_no_subtype_ndarray__ = True
0 commit comments