@@ -340,22 +340,29 @@ py_searchsorted(const dpctl::tensor::usm_ndarray &hay,
340340 int simplified_nd = needles_nd;
341341
342342 using shT = std::vector<py::ssize_t >;
343-
344343 shT simplified_common_shape;
345344 shT simplified_needles_strides;
346345 shT simplified_positions_strides;
347346 py::ssize_t needles_offset (0 );
348347 py::ssize_t positions_offset (0 );
349348
350- dpctl::tensor::py_internal::simplify_iteration_space (
351- // modified by refernce
352- simplified_nd,
353- // read-only inputs
354- needles_shape_ptr, needles_strides, positions_strides,
355- // output, modified by reference
356- simplified_common_shape, simplified_needles_strides,
357- simplified_positions_strides, needles_offset, positions_offset);
358-
349+ if (simplified_nd == 0 ) {
350+ // needles and positions have same nd
351+ simplified_nd = 1 ;
352+ simplified_common_shape.push_back (1 );
353+ simplified_needles_strides.push_back (0 );
354+ simplified_positions_strides.push_back (0 );
355+ }
356+ else {
357+ dpctl::tensor::py_internal::simplify_iteration_space (
358+ // modified by refernce
359+ simplified_nd,
360+ // read-only inputs
361+ needles_shape_ptr, needles_strides, positions_strides,
362+ // output, modified by reference
363+ simplified_common_shape, simplified_needles_strides,
364+ simplified_positions_strides, needles_offset, positions_offset);
365+ }
359366 std::vector<sycl::event> host_task_events;
360367 host_task_events.reserve (2 );
361368
0 commit comments