Skip to content

Commit 545a819

Browse files
committed
Don't predicate inside VectorSubs
1 parent a2ba86e commit 545a819

File tree

1 file changed

+23
-167
lines changed

1 file changed

+23
-167
lines changed

src/VectorizeLoops.cpp

Lines changed: 23 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -360,114 +360,6 @@ class SerializeLoops : public IRMutator {
360360
}
361361
};
362362

363-
// Wrap a vectorized predicate around a Load/Store node.
364-
class PredicateLoadStore : public IRMutator {
365-
string var;
366-
Expr vector_predicate;
367-
int lanes;
368-
bool valid = true;
369-
bool vectorized = false;
370-
371-
using IRMutator::visit;
372-
373-
Expr merge_predicate(Expr pred, const Expr &new_pred) {
374-
if (pred.type().lanes() == new_pred.type().lanes()) {
375-
Expr res = simplify(pred && new_pred);
376-
return res;
377-
}
378-
valid = false;
379-
return pred;
380-
}
381-
382-
Stmt visit(const Atomic *op) override {
383-
// We don't support codegen for vectorized predicated atomic stores, so
384-
// just bail out.
385-
valid = false;
386-
return op;
387-
}
388-
389-
Expr visit(const Load *op) override {
390-
valid = valid && ((op->predicate.type().lanes() == lanes) || (op->predicate.type().is_scalar() && !expr_uses_var(op->index, var)));
391-
if (!valid) {
392-
return op;
393-
}
394-
395-
Expr predicate, index;
396-
if (!op->index.type().is_scalar()) {
397-
internal_assert(op->predicate.type().lanes() == lanes);
398-
internal_assert(op->index.type().lanes() == lanes);
399-
400-
predicate = mutate(op->predicate);
401-
index = mutate(op->index);
402-
} else if (expr_uses_var(op->index, var)) {
403-
predicate = mutate(Broadcast::make(op->predicate, lanes));
404-
index = mutate(Broadcast::make(op->index, lanes));
405-
} else {
406-
return IRMutator::visit(op);
407-
}
408-
409-
predicate = merge_predicate(predicate, vector_predicate);
410-
if (!valid) {
411-
return op;
412-
}
413-
vectorized = true;
414-
return Load::make(op->type, op->name, index, op->image, op->param, predicate, op->alignment);
415-
}
416-
417-
Stmt visit(const Store *op) override {
418-
valid = valid && ((op->predicate.type().lanes() == lanes) || (op->predicate.type().is_scalar() && !expr_uses_var(op->index, var)));
419-
if (!valid) {
420-
return op;
421-
}
422-
423-
Expr predicate, value, index;
424-
if (!op->index.type().is_scalar()) {
425-
internal_assert(op->predicate.type().lanes() == lanes);
426-
internal_assert(op->index.type().lanes() == lanes);
427-
internal_assert(op->value.type().lanes() == lanes);
428-
429-
predicate = mutate(op->predicate);
430-
value = mutate(op->value);
431-
index = mutate(op->index);
432-
} else if (expr_uses_var(op->index, var)) {
433-
predicate = mutate(Broadcast::make(op->predicate, lanes));
434-
value = mutate(Broadcast::make(op->value, lanes));
435-
index = mutate(Broadcast::make(op->index, lanes));
436-
} else {
437-
return IRMutator::visit(op);
438-
}
439-
440-
predicate = merge_predicate(predicate, vector_predicate);
441-
if (!valid) {
442-
return op;
443-
}
444-
vectorized = true;
445-
return Store::make(op->name, value, index, op->param, predicate, op->alignment);
446-
}
447-
448-
Expr visit(const Call *op) override {
449-
// We should not vectorize calls with side-effects
450-
valid = valid && op->is_pure();
451-
return IRMutator::visit(op);
452-
}
453-
454-
Expr visit(const VectorReduce *op) override {
455-
// We can't predicate vector reductions.
456-
valid = valid && is_const_one(vector_predicate);
457-
return op;
458-
}
459-
460-
public:
461-
PredicateLoadStore(string v, const Expr &vpred)
462-
: var(std::move(v)), vector_predicate(vpred), lanes(vpred.type().lanes()) {
463-
internal_assert(lanes > 1);
464-
}
465-
466-
bool is_vectorized() const {
467-
return valid && vectorized;
468-
}
469-
};
470-
471363
Stmt vectorize_statement(const Stmt &stmt);
472364

473365
struct VectorizedVar {
@@ -859,25 +751,6 @@ class VectorSubs : public IRMutator {
859751
// which would mean control flow divergence within the
860752
// SIMD lanes.
861753

862-
bool vectorize_predicate = true;
863-
864-
Stmt predicated_stmt;
865-
if (vectorize_predicate) {
866-
PredicateLoadStore p(vectorized_vars.front().name, cond);
867-
predicated_stmt = p.mutate(then_case);
868-
vectorize_predicate = p.is_vectorized();
869-
}
870-
if (vectorize_predicate && else_case.defined()) {
871-
PredicateLoadStore p(vectorized_vars.front().name, !cond);
872-
predicated_stmt = Block::make(predicated_stmt, p.mutate(else_case));
873-
vectorize_predicate = p.is_vectorized();
874-
}
875-
876-
debug(4) << "IfThenElse should vectorize predicate "
877-
<< "? " << vectorize_predicate << "; cond: " << cond << "\n";
878-
debug(4) << "Predicated stmt:\n"
879-
<< predicated_stmt << "\n";
880-
881754
// First check if the condition is marked as likely.
882755
if (const Call *likely = Call::as_intrinsic(cond, {Call::likely, Call::likely_if_innermost})) {
883756

@@ -892,48 +765,31 @@ class VectorSubs : public IRMutator {
892765
all_true = Call::make(Bool(), likely->name,
893766
{all_true}, Call::PureIntrinsic);
894767

895-
if (!vectorize_predicate) {
896-
// We should strip the likelies from the case
897-
// that's going to scalarize, because it's no
898-
// longer likely.
899-
Stmt without_likelies =
900-
IfThenElse::make(unwrap_tags(op->condition),
901-
op->then_case, op->else_case);
902-
903-
// scalarize() will put back all vectorized loops around the statement as serial,
904-
// but it still may happen that there are vectorized loops inside of the statement
905-
// itself which we may want to handle. All the context is invalid though, so
906-
// we just start anew for this specific statement.
907-
Stmt scalarized = scalarize(without_likelies, false);
908-
scalarized = vectorize_statement(scalarized);
909-
Stmt stmt =
910-
IfThenElse::make(all_true,
911-
then_case,
912-
scalarized);
913-
debug(4) << "...With all_true likely: \n"
914-
<< stmt << "\n";
915-
return stmt;
916-
} else {
917-
Stmt stmt =
918-
IfThenElse::make(all_true,
919-
then_case,
920-
predicated_stmt);
921-
debug(4) << "...Predicated IfThenElse: \n"
922-
<< stmt << "\n";
923-
return stmt;
924-
}
768+
// We should strip the likelies from the case
769+
// that's going to scalarize, because it's no
770+
// longer likely.
771+
Stmt without_likelies =
772+
IfThenElse::make(unwrap_tags(op->condition),
773+
op->then_case, op->else_case);
774+
775+
// scalarize() will put back all vectorized loops around the statement as serial,
776+
// but it still may happen that there are vectorized loops inside of the statement
777+
// itself which we may want to handle. All the context is invalid though, so
778+
// we just start anew for this specific statement.
779+
Stmt scalarized = scalarize(without_likelies, false);
780+
scalarized = vectorize_statement(scalarized);
781+
Stmt stmt =
782+
IfThenElse::make(all_true,
783+
then_case,
784+
scalarized);
785+
debug(4) << "...With all_true likely: \n"
786+
<< stmt << "\n";
787+
return stmt;
925788
} else {
926789
// It's some arbitrary vector condition.
927-
if (!vectorize_predicate) {
928-
debug(4) << "...Scalarizing vector predicate: \n"
929-
<< Stmt(op) << "\n";
930-
return scalarize(op);
931-
} else {
932-
Stmt stmt = predicated_stmt;
933-
debug(4) << "...Predicated IfThenElse: \n"
934-
<< stmt << "\n";
935-
return stmt;
936-
}
790+
debug(4) << "...Scalarizing vector predicate: \n"
791+
<< Stmt(op) << "\n";
792+
return scalarize(op);
937793
}
938794
} else {
939795
// It's an if statement on a scalar, we're ok to vectorize the innards.

0 commit comments

Comments
 (0)