@@ -539,34 +539,32 @@ sort_over_work_group_contig_impl(sycl::queue &q,
539539 sycl::group_barrier (it.get_group ());
540540
541541 bool data_in_temp = false ;
542- size_t sorted_size = 1 ;
543- while (true ) {
544- const size_t nelems_sorted_so_far = sorted_size * chunk;
545- if (nelems_sorted_so_far < wg_chunk_size) {
546- const size_t q = (lid / sorted_size);
547- const size_t start_1 =
548- sycl::min (2 * nelems_sorted_so_far * q, wg_chunk_size);
549- const size_t end_1 = sycl::min (
550- start_1 + nelems_sorted_so_far, wg_chunk_size);
551- const size_t end_2 =
552- sycl::min (end_1 + nelems_sorted_so_far, wg_chunk_size);
553- const size_t offset = chunk * (lid - q * sorted_size);
554-
555- if (data_in_temp) {
556- merge_impl (offset, scratch_space, work_space, start_1,
557- end_1, end_2, start_1, comp, chunk);
558- }
559- else {
560- merge_impl (offset, work_space, scratch_space, start_1,
561- end_1, end_2, start_1, comp, chunk);
562- }
563- sycl::group_barrier (it.get_group ());
564-
565- data_in_temp = !data_in_temp;
566- sorted_size *= 2 ;
542+ size_t n_chunks_merged = 1 ;
543+
544+ // keep merging chunks while n_chunks_merged * chunk < wg_chunk_size
545+ const size_t max_chunks_merged = 1 + ((wg_chunk_size - 1 ) / chunk);
546+ for (; n_chunks_merged < max_chunks_merged;
547+ data_in_temp = !data_in_temp, n_chunks_merged *= 2 )
548+ {
549+ const size_t nelems_sorted_so_far = n_chunks_merged * chunk;
550+ const size_t q = (lid / n_chunks_merged);
551+ const size_t start_1 =
552+ sycl::min (2 * nelems_sorted_so_far * q, wg_chunk_size);
553+ const size_t end_1 =
554+ sycl::min (start_1 + nelems_sorted_so_far, wg_chunk_size);
555+ const size_t end_2 =
556+ sycl::min (end_1 + nelems_sorted_so_far, wg_chunk_size);
557+ const size_t offset = chunk * (lid - q * n_chunks_merged);
558+
559+ if (data_in_temp) {
560+ merge_impl (offset, scratch_space, work_space, start_1,
561+ end_1, end_2, start_1, comp, chunk);
562+ }
563+ else {
564+ merge_impl (offset, work_space, scratch_space, start_1,
565+ end_1, end_2, start_1, comp, chunk);
567566 }
568- else
569- break ;
567+ sycl::group_barrier (it.get_group ());
570568 }
571569
572570 const auto &out_src = (data_in_temp) ? scratch_space : work_space;
0 commit comments