@@ -408,9 +408,8 @@ int simplify_iteration_stride(const int nd,
408408
409409 The new shape and new strides, as well as the offset
410410 `(new_shape, new_strides1, disp1, new_stride2, disp2)` are such that
411- iterating over them will traverse the same pairs of elements, possibly in
412- different order.
413-
411+ iterating over them will traverse the same set of pairs of elements,
412+ possibly in a different order.
414413 */
415414template <class ShapeTy , class StridesTy >
416415int simplify_iteration_two_strides (const int nd,
@@ -447,7 +446,7 @@ int simplify_iteration_two_strides(const int nd,
447446 auto str1_p = strides1[p];
448447 auto str2_p = strides2[p];
449448 shape_w.push_back (sh_p);
450- if (str1_p < 0 && str2_p < 0 ) {
449+ if (str1_p <= 0 && str2_p <= 0 && std::min (str1_p, str2_p) < 0 ) {
451450 disp1 += str1_p * (sh_p - 1 );
452451 str1_p = -str1_p;
453452 disp2 += str2_p * (sh_p - 1 );
@@ -468,7 +467,7 @@ int simplify_iteration_two_strides(const int nd,
468467 StridesTy jump1 = strides1_w[i] - (shape_w[i + 1 ] - 1 ) * str1;
469468 StridesTy jump2 = strides2_w[i] - (shape_w[i + 1 ] - 1 ) * str2;
470469
471- if (jump1 == str1 and jump2 == str2) {
470+ if (jump1 == str1 && jump2 == str2) {
472471 changed = true ;
473472 shape_w[i] *= shape_w[i + 1 ];
474473 for (int j = i; j < nd_; ++j) {
@@ -540,3 +539,148 @@ contract_iter2(vecT shape, vecT strides1, vecT strides2)
540539 out_strides2.resize (nd);
541540 return std::make_tuple (out_shape, out_strides1, disp1, out_strides2, disp2);
542541}
542+
/*
   For purposes of iterating over triples of elements of three arrays
   with `shape` and strides `strides1`, `strides2`, `strides3` given as
   pointers, `simplify_iteration_three_strides(nd, shape_ptr, strides1_ptr,
   strides2_ptr, strides3_ptr, disp1, disp2, disp3)`
   may modify memory and returns the new length of these arrays.

   The new shape and new strides, as well as the offsets
   `(new_shape, new_strides1, disp1, new_strides2, disp2, new_strides3, disp3)`
   are such that iterating over them will traverse the same set of tuples of
   elements, possibly in a different order.

   Details:
     - `disp1`, `disp2` and `disp3` are pure outputs: all three are reset to
       zero on entry, including when `nd < 2` and nothing else is done.
     - Axes are reordered by decreasing `|strides1|`; ties are broken by
       placing the larger extent first.
     - An axis is reversed (strides negated, displacements adjusted) only
       when all three of its strides are non-positive and at least one is
       strictly negative.
     - Adjacent axes are merged only if no negative stride remains after
       the reversal step.
     - Only the first `nd_` (returned value) entries of `shape` and of each
       strides array are written back.
*/
template <class ShapeTy, class StridesTy>
int simplify_iteration_three_strides(const int nd,
                                     ShapeTy *shape,
                                     StridesTy *strides1,
                                     StridesTy *strides2,
                                     StridesTy *strides3,
                                     StridesTy &disp1,
                                     StridesTy &disp2,
                                     StridesTy &disp3)
{
    disp1 = std::ptrdiff_t(0);
    disp2 = std::ptrdiff_t(0);
    // disp3 must be reset as well; otherwise a stale caller value would be
    // accumulated into (or returned as) the third displacement.
    disp3 = std::ptrdiff_t(0);
    if (nd < 2)
        return nd;

    // Permutation of axes: largest |strides1| first, larger extent on ties.
    std::vector<int> pos(nd);
    std::iota(pos.begin(), pos.end(), 0);

    std::stable_sort(
        pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) {
            auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
            auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
            return (abs_str1 > abs_str2) ||
                   (abs_str1 == abs_str2 && shape[i1] > shape[i2]);
        });

    std::vector<ShapeTy> shape_w;
    std::vector<StridesTy> strides1_w;
    std::vector<StridesTy> strides2_w;
    std::vector<StridesTy> strides3_w;
    shape_w.reserve(nd);
    strides1_w.reserve(nd);
    strides2_w.reserve(nd);
    strides3_w.reserve(nd);

    bool contractable = true;
    for (int i = 0; i < nd; ++i) {
        auto p = pos[i];
        auto sh_p = shape[p];
        auto str1_p = strides1[p];
        auto str2_p = strides2[p];
        auto str3_p = strides3[p];
        shape_w.push_back(sh_p);
        // Reverse the axis when no stride is positive and at least one is
        // negative: start from the last element and walk forward instead.
        if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
            std::min(std::min(str1_p, str2_p), str3_p) < 0)
        {
            disp1 += str1_p * (sh_p - 1);
            str1_p = -str1_p;
            disp2 += str2_p * (sh_p - 1);
            str2_p = -str2_p;
            disp3 += str3_p * (sh_p - 1);
            str3_p = -str3_p;
        }
        // Mixed-sign axes cannot be reversed consistently; merging is then
        // disabled for the whole iteration space.
        if (str1_p < 0 || str2_p < 0 || str3_p < 0) {
            contractable = false;
        }
        strides1_w.push_back(str1_p);
        strides2_w.push_back(str2_p);
        strides3_w.push_back(str3_p);
    }

    int nd_ = nd;
    while (contractable) {
        bool changed = false;
        for (int i = 0; i + 1 < nd_; ++i) {
            const StridesTy str1 = strides1_w[i + 1];
            const StridesTy str2 = strides2_w[i + 1];
            const StridesTy str3 = strides3_w[i + 1];
            // Step taken when the inner axis wraps around and the outer
            // axis advances by one.
            const StridesTy jump1 =
                strides1_w[i] - (shape_w[i + 1] - 1) * str1;
            const StridesTy jump2 =
                strides2_w[i] - (shape_w[i + 1] - 1) * str2;
            const StridesTy jump3 =
                strides3_w[i] - (shape_w[i + 1] - 1) * str3;

            if (jump1 == str1 && jump2 == str2 && jump3 == str3) {
                // Axes i and i + 1 form one uniformly strided axis: keep
                // the inner strides and the product of the extents.
                changed = true;
                shape_w[i] *= shape_w[i + 1];
                shape_w.erase(shape_w.begin() + (i + 1));
                strides1_w.erase(strides1_w.begin() + i);
                strides2_w.erase(strides2_w.begin() + i);
                strides3_w.erase(strides3_w.begin() + i);
                --nd_;
                break; // rescan from the first axis
            }
        }
        if (!changed)
            break;
    }

    // Each working vector now holds exactly nd_ entries.
    std::copy(shape_w.begin(), shape_w.end(), shape);
    std::copy(strides1_w.begin(), strides1_w.end(), strides1);
    std::copy(strides2_w.begin(), strides2_w.end(), strides2);
    std::copy(strides3_w.begin(), strides3_w.end(), strides3);

    return nd_;
}
659+
/*
   Vector-based convenience wrapper around
   `simplify_iteration_three_strides`.

   Takes `shape` and three stride vectors by value, simplifies them in
   place, truncates each to the simplified number of dimensions and returns
   `(shape, strides1, disp1, strides2, disp2, strides3, disp3)`.

   Throws `Error` (constructed from a message string) if the four input
   vectors do not all have the same size.
*/
template <typename T, class Error, typename vecT = std::vector<T>>
std::tuple<vecT, vecT, T, vecT, T, vecT, T>
contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
{
    const size_t dim = shape.size();
    if (dim != strides1.size() || dim != strides2.size() ||
        dim != strides3.size())
    {
        throw Error("Shape and strides must be of equal size.");
    }

    T disp1(0);
    T disp2(0);
    T disp3(0);

    // The arguments are already private copies (passed by value), so they
    // are simplified in place rather than being duplicated a second time.
    const int nd = simplify_iteration_three_strides(
        static_cast<int>(dim), shape.data(), strides1.data(), strides2.data(),
        strides3.data(), disp1, disp2, disp3);

    // Drop the trailing entries that the simplification no longer uses.
    shape.resize(nd);
    strides1.resize(nd);
    strides2.resize(nd);
    strides3.resize(nd);

    return std::make_tuple(shape, strides1, disp1, strides2, disp2, strides3,
                           disp3);
}
0 commit comments