@@ -52,6 +52,15 @@ cdef class InternalUSMArrayError(Exception):
5252 pass
5353
5454
55+ cdef object _as_zero_dim_ndarray(object usm_ary):
56+ " Convert size-1 array to NumPy 0d array"
57+ mem_view = dpmem.as_usm_memory(usm_ary)
58+ host_buf = mem_view.copy_to_host()
59+ view = host_buf.view(usm_ary.dtype)
60+ view.shape = tuple ()
61+ return view
62+
63+
5564cdef class usm_ndarray:
5665 """ usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
5766 offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -840,9 +849,7 @@ cdef class usm_ndarray:
840849
841850 def __bool__ (self ):
842851 if self .size == 1 :
843- mem_view = dpmem.as_usm_memory(self )
844- host_buf = mem_view.copy_to_host()
845- view = host_buf.view(self .dtype)
852+ view = _as_zero_dim_ndarray(self )
846853 return view.__bool__()
847854
848855 if self .size == 0 :
@@ -857,9 +864,7 @@ cdef class usm_ndarray:
857864
858865 def __float__ (self ):
859866 if self .size == 1 :
860- mem_view = dpmem.as_usm_memory(self )
861- host_buf = mem_view.copy_to_host()
862- view = host_buf.view(self .dtype)
867+ view = _as_zero_dim_ndarray(self )
863868 return view.__float__ ()
864869
865870 raise ValueError (
@@ -868,9 +873,7 @@ cdef class usm_ndarray:
868873
869874 def __complex__ (self ):
870875 if self .size == 1 :
871- mem_view = dpmem.as_usm_memory(self )
872- host_buf = mem_view.copy_to_host()
873- view = host_buf.view(self .dtype)
876+ view = _as_zero_dim_ndarray(self )
874877 return view.__complex__ ()
875878
876879 raise ValueError (
@@ -879,9 +882,7 @@ cdef class usm_ndarray:
879882
880883 def __int__ (self ):
881884 if self .size == 1 :
882- mem_view = dpmem.as_usm_memory(self )
883- host_buf = mem_view.copy_to_host()
884- view = host_buf.view(self .dtype)
885+ view = _as_zero_dim_ndarray(self )
885886 return view.__int__ ()
886887
887888 raise ValueError (
0 commit comments