@@ -428,11 +428,19 @@ int simplify_iteration_two_strides(const int nd,
428428 std::iota (pos.begin (), pos.end (), 0 );
429429
430430 std::stable_sort (
431- pos.begin (), pos.end (), [&strides1, &shape](int i1, int i2) {
432- auto abs_str1 = (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
433- auto abs_str2 = (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
434- return (abs_str1 > abs_str2) ||
435- (abs_str1 == abs_str2 && shape[i1] > shape[i2]);
431+ pos.begin (), pos.end (), [&strides1, &strides2, &shape](int i1, int i2) {
432+ auto abs_str1_i1 =
433+ (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
434+ auto abs_str1_i2 =
435+ (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
436+ auto abs_str2_i1 =
437+ (strides2[i1] < 0 ) ? -strides2[i1] : strides2[i1];
438+ auto abs_str2_i2 =
439+ (strides2[i2] < 0 ) ? -strides2[i2] : strides2[i2];
440+ return (abs_str1_i1 > abs_str1_i2) ||
441+ (abs_str1_i1 == abs_str1_i2 &&
442+ (abs_str2_i1 > abs_str2_i2 ||
443+ (abs_str2_i1 == abs_str2_i2 && shape[i1] > shape[i2])));
436444 });
437445
438446 std::vector<ShapeTy> shape_w;
@@ -458,6 +466,7 @@ int simplify_iteration_two_strides(const int nd,
458466 strides1_w.push_back (str1_p);
459467 strides2_w.push_back (str2_p);
460468 }
469+
461470 int nd_ = nd;
462471 while (contractable) {
463472 bool changed = false ;
@@ -570,13 +579,28 @@ int simplify_iteration_three_strides(const int nd,
570579 std::vector<int > pos (nd);
571580 std::iota (pos.begin (), pos.end (), 0 );
572581
573- std::stable_sort (
574- pos.begin (), pos.end (), [&strides1, &shape](int i1, int i2) {
575- auto abs_str1 = (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
576- auto abs_str2 = (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
577- return (abs_str1 > abs_str2) ||
578- (abs_str1 == abs_str2 && shape[i1] > shape[i2]);
579- });
582+ std::stable_sort (pos.begin (), pos.end (),
583+ [&strides1, &strides2, &strides3, &shape](int i1, int i2) {
584+ auto abs_str1_i1 =
585+ (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
586+ auto abs_str1_i2 =
587+ (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
588+ auto abs_str2_i1 =
589+ (strides2[i1] < 0 ) ? -strides2[i1] : strides2[i1];
590+ auto abs_str2_i2 =
591+ (strides2[i2] < 0 ) ? -strides2[i2] : strides2[i2];
592+ auto abs_str3_i1 =
593+ (strides3[i1] < 0 ) ? -strides3[i1] : strides3[i1];
594+ auto abs_str3_i2 =
595+ (strides3[i2] < 0 ) ? -strides3[i2] : strides3[i2];
596+ return (abs_str1_i1 > abs_str1_i2) ||
597+ ((abs_str1_i1 == abs_str1_i2) &&
598+ ((abs_str2_i1 > abs_str2_i2) ||
599+ ((abs_str2_i1 == abs_str2_i2) &&
600+ ((abs_str3_i1 > abs_str3_i2) ||
601+ ((abs_str3_i1 == abs_str3_i2) &&
602+ (shape[i1] > shape[i2]))))));
603+ });
580604
581605 std::vector<ShapeTy> shape_w;
582606 std::vector<StridesTy> strides1_w;
0 commit comments