@@ -360,114 +360,6 @@ class SerializeLoops : public IRMutator {
360360 }
361361};
362362
363- // Wrap a vectorized predicate around a Load/Store node.
364- class PredicateLoadStore : public IRMutator {
365- string var;
366- Expr vector_predicate;
367- int lanes;
368- bool valid = true ;
369- bool vectorized = false ;
370-
371- using IRMutator::visit;
372-
373- Expr merge_predicate (Expr pred, const Expr &new_pred) {
374- if (pred.type ().lanes () == new_pred.type ().lanes ()) {
375- Expr res = simplify (pred && new_pred);
376- return res;
377- }
378- valid = false ;
379- return pred;
380- }
381-
382- Stmt visit (const Atomic *op) override {
383- // We don't support codegen for vectorized predicated atomic stores, so
384- // just bail out.
385- valid = false ;
386- return op;
387- }
388-
389- Expr visit (const Load *op) override {
390- valid = valid && ((op->predicate .type ().lanes () == lanes) || (op->predicate .type ().is_scalar () && !expr_uses_var (op->index , var)));
391- if (!valid) {
392- return op;
393- }
394-
395- Expr predicate, index;
396- if (!op->index .type ().is_scalar ()) {
397- internal_assert (op->predicate .type ().lanes () == lanes);
398- internal_assert (op->index .type ().lanes () == lanes);
399-
400- predicate = mutate (op->predicate );
401- index = mutate (op->index );
402- } else if (expr_uses_var (op->index , var)) {
403- predicate = mutate (Broadcast::make (op->predicate , lanes));
404- index = mutate (Broadcast::make (op->index , lanes));
405- } else {
406- return IRMutator::visit (op);
407- }
408-
409- predicate = merge_predicate (predicate, vector_predicate);
410- if (!valid) {
411- return op;
412- }
413- vectorized = true ;
414- return Load::make (op->type , op->name , index, op->image , op->param , predicate, op->alignment );
415- }
416-
417- Stmt visit (const Store *op) override {
418- valid = valid && ((op->predicate .type ().lanes () == lanes) || (op->predicate .type ().is_scalar () && !expr_uses_var (op->index , var)));
419- if (!valid) {
420- return op;
421- }
422-
423- Expr predicate, value, index;
424- if (!op->index .type ().is_scalar ()) {
425- internal_assert (op->predicate .type ().lanes () == lanes);
426- internal_assert (op->index .type ().lanes () == lanes);
427- internal_assert (op->value .type ().lanes () == lanes);
428-
429- predicate = mutate (op->predicate );
430- value = mutate (op->value );
431- index = mutate (op->index );
432- } else if (expr_uses_var (op->index , var)) {
433- predicate = mutate (Broadcast::make (op->predicate , lanes));
434- value = mutate (Broadcast::make (op->value , lanes));
435- index = mutate (Broadcast::make (op->index , lanes));
436- } else {
437- return IRMutator::visit (op);
438- }
439-
440- predicate = merge_predicate (predicate, vector_predicate);
441- if (!valid) {
442- return op;
443- }
444- vectorized = true ;
445- return Store::make (op->name , value, index, op->param , predicate, op->alignment );
446- }
447-
448- Expr visit (const Call *op) override {
449- // We should not vectorize calls with side-effects
450- valid = valid && op->is_pure ();
451- return IRMutator::visit (op);
452- }
453-
454- Expr visit (const VectorReduce *op) override {
455- // We can't predicate vector reductions.
456- valid = valid && is_const_one (vector_predicate);
457- return op;
458- }
459-
460- public:
461- PredicateLoadStore (string v, const Expr &vpred)
462- : var(std::move(v)), vector_predicate(vpred), lanes(vpred.type().lanes()) {
463- internal_assert (lanes > 1 );
464- }
465-
466- bool is_vectorized () const {
467- return valid && vectorized;
468- }
469- };
470-
471363Stmt vectorize_statement (const Stmt &stmt);
472364
473365struct VectorizedVar {
@@ -859,25 +751,6 @@ class VectorSubs : public IRMutator {
859751 // which would mean control flow divergence within the
860752 // SIMD lanes.
861753
862- bool vectorize_predicate = true ;
863-
864- Stmt predicated_stmt;
865- if (vectorize_predicate) {
866- PredicateLoadStore p (vectorized_vars.front ().name , cond);
867- predicated_stmt = p.mutate (then_case);
868- vectorize_predicate = p.is_vectorized ();
869- }
870- if (vectorize_predicate && else_case.defined ()) {
871- PredicateLoadStore p (vectorized_vars.front ().name , !cond);
872- predicated_stmt = Block::make (predicated_stmt, p.mutate (else_case));
873- vectorize_predicate = p.is_vectorized ();
874- }
875-
876- debug (4 ) << " IfThenElse should vectorize predicate "
877- << " ? " << vectorize_predicate << " ; cond: " << cond << " \n " ;
878- debug (4 ) << " Predicated stmt:\n "
879- << predicated_stmt << " \n " ;
880-
881754 // First check if the condition is marked as likely.
882755 if (const Call *likely = Call::as_intrinsic (cond, {Call::likely, Call::likely_if_innermost})) {
883756
@@ -892,48 +765,31 @@ class VectorSubs : public IRMutator {
892765 all_true = Call::make (Bool (), likely->name ,
893766 {all_true}, Call::PureIntrinsic);
894767
895- if (!vectorize_predicate) {
896- // We should strip the likelies from the case
897- // that's going to scalarize, because it's no
898- // longer likely.
899- Stmt without_likelies =
900- IfThenElse::make (unwrap_tags (op->condition ),
901- op->then_case , op->else_case );
902-
903- // scalarize() will put back all vectorized loops around the statement as serial,
904- // but it still may happen that there are vectorized loops inside of the statement
905- // itself which we may want to handle. All the context is invalid though, so
906- // we just start anew for this specific statement.
907- Stmt scalarized = scalarize (without_likelies, false );
908- scalarized = vectorize_statement (scalarized);
909- Stmt stmt =
910- IfThenElse::make (all_true,
911- then_case,
912- scalarized);
913- debug (4 ) << " ...With all_true likely: \n "
914- << stmt << " \n " ;
915- return stmt;
916- } else {
917- Stmt stmt =
918- IfThenElse::make (all_true,
919- then_case,
920- predicated_stmt);
921- debug (4 ) << " ...Predicated IfThenElse: \n "
922- << stmt << " \n " ;
923- return stmt;
924- }
768+ // We should strip the likelies from the case
769+ // that's going to scalarize, because it's no
770+ // longer likely.
771+ Stmt without_likelies =
772+ IfThenElse::make (unwrap_tags (op->condition ),
773+ op->then_case , op->else_case );
774+
775+ // scalarize() will put back all vectorized loops around the statement as serial,
776+ // but it still may happen that there are vectorized loops inside of the statement
777+ // itself which we may want to handle. All the context is invalid though, so
778+ // we just start anew for this specific statement.
779+ Stmt scalarized = scalarize (without_likelies, false );
780+ scalarized = vectorize_statement (scalarized);
781+ Stmt stmt =
782+ IfThenElse::make (all_true,
783+ then_case,
784+ scalarized);
785+ debug (4 ) << " ...With all_true likely: \n "
786+ << stmt << " \n " ;
787+ return stmt;
925788 } else {
926789 // It's some arbitrary vector condition.
927- if (!vectorize_predicate) {
928- debug (4 ) << " ...Scalarizing vector predicate: \n "
929- << Stmt (op) << " \n " ;
930- return scalarize (op);
931- } else {
932- Stmt stmt = predicated_stmt;
933- debug (4 ) << " ...Predicated IfThenElse: \n "
934- << stmt << " \n " ;
935- return stmt;
936- }
790+ debug (4 ) << " ...Scalarizing vector predicate: \n "
791+ << Stmt (op) << " \n " ;
792+ return scalarize (op);
937793 }
938794 } else {
939795 // It's an if statement on a scalar, we're ok to vectorize the innards.
0 commit comments