diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp
index d06cc0815300..5f6aee541b28 100644
--- a/src/VectorizeLoops.cpp
+++ b/src/VectorizeLoops.cpp
@@ -1,7 +1,6 @@
 #include <algorithm>
 #include <utility>
 
-#include "CSE.h"
 #include "CodeGen_GPU_Dev.h"
 #include "Deinterleave.h"
 #include "ExprUsesVar.h"
@@ -361,114 +360,6 @@ class SerializeLoops : public IRMutator {
     }
 };
 
-// Wrap a vectorized predicate around a Load/Store node.
-class PredicateLoadStore : public IRMutator {
-    string var;
-    Expr vector_predicate;
-    int lanes;
-    bool valid = true;
-    bool vectorized = false;
-
-    using IRMutator::visit;
-
-    Expr merge_predicate(Expr pred, const Expr &new_pred) {
-        if (pred.type().lanes() == new_pred.type().lanes()) {
-            Expr res = simplify(pred && new_pred);
-            return res;
-        }
-        valid = false;
-        return pred;
-    }
-
-    Stmt visit(const Atomic *op) override {
-        // We don't support codegen for vectorized predicated atomic stores, so
-        // just bail out.
-        valid = false;
-        return op;
-    }
-
-    Expr visit(const Load *op) override {
-        valid = valid && ((op->predicate.type().lanes() == lanes) || (op->predicate.type().is_scalar() && !expr_uses_var(op->index, var)));
-        if (!valid) {
-            return op;
-        }
-
-        Expr predicate, index;
-        if (!op->index.type().is_scalar()) {
-            internal_assert(op->predicate.type().lanes() == lanes);
-            internal_assert(op->index.type().lanes() == lanes);
-
-            predicate = mutate(op->predicate);
-            index = mutate(op->index);
-        } else if (expr_uses_var(op->index, var)) {
-            predicate = mutate(Broadcast::make(op->predicate, lanes));
-            index = mutate(Broadcast::make(op->index, lanes));
-        } else {
-            return IRMutator::visit(op);
-        }
-
-        predicate = merge_predicate(predicate, vector_predicate);
-        if (!valid) {
-            return op;
-        }
-        vectorized = true;
-        return Load::make(op->type, op->name, index, op->image, op->param, predicate, op->alignment);
-    }
-
-    Stmt visit(const Store *op) override {
-        valid = valid && ((op->predicate.type().lanes() == lanes) || (op->predicate.type().is_scalar() && !expr_uses_var(op->index, var)));
-        if (!valid) {
-            return op;
-        }
-
-        Expr predicate, value, index;
-        if (!op->index.type().is_scalar()) {
-            internal_assert(op->predicate.type().lanes() == lanes);
-            internal_assert(op->index.type().lanes() == lanes);
-            internal_assert(op->value.type().lanes() == lanes);
-
-            predicate = mutate(op->predicate);
-            value = mutate(op->value);
-            index = mutate(op->index);
-        } else if (expr_uses_var(op->index, var)) {
-            predicate = mutate(Broadcast::make(op->predicate, lanes));
-            value = mutate(Broadcast::make(op->value, lanes));
-            index = mutate(Broadcast::make(op->index, lanes));
-        } else {
-            return IRMutator::visit(op);
-        }
-
-        predicate = merge_predicate(predicate, vector_predicate);
-        if (!valid) {
-            return op;
-        }
-        vectorized = true;
-        return Store::make(op->name, value, index, op->param, predicate, op->alignment);
-    }
-
-    Expr visit(const Call *op) override {
-        // We should not vectorize calls with side-effects
-        valid = valid && op->is_pure();
-        return IRMutator::visit(op);
-    }
-
-    Expr visit(const VectorReduce *op) override {
-        // We can't predicate vector reductions.
-        valid = valid && is_const_one(vector_predicate);
-        return op;
-    }
-
-public:
-    PredicateLoadStore(string v, const Expr &vpred)
-        : var(std::move(v)), vector_predicate(vpred), lanes(vpred.type().lanes()) {
-        internal_assert(lanes > 1);
-    }
-
-    bool is_vectorized() const {
-        return valid && vectorized;
-    }
-};
-
 Stmt vectorize_statement(const Stmt &stmt);
 
 struct VectorizedVar {
@@ -502,16 +393,16 @@ class VectorSubs : public IRMutator {
     vector<pair<string, Expr>> containing_lets;
 
     // Widen an expression to the given number of lanes.
-    Expr widen(Expr e, int lanes) {
+    static Expr widen(Expr e, int lanes) {
        if (e.type().lanes() == lanes) {
            return e;
-        } else if (lanes % e.type().lanes() == 0) {
+        }
+        if (lanes % e.type().lanes() == 0) {
            return Broadcast::make(e, lanes / e.type().lanes());
-        } else {
-            internal_error << "Mismatched vector lanes in VectorSubs " << e.type().lanes()
-                           << " " << lanes << "\n";
        }
-        return Expr();
+        internal_error
+            << "Cannot widen " << e.type().lanes() << " lanes to " << lanes << ".\n"
+            << "Expression: " << e << "\n";
    }
 
    using IRMutator::visit;
@@ -612,19 +503,19 @@ class VectorSubs : public IRMutator {
        Expr condition = mutate(op->condition);
        Expr true_value = mutate(op->true_value);
        Expr false_value = mutate(op->false_value);
+
        if (condition.same_as(op->condition) &&
            true_value.same_as(op->true_value) &&
            false_value.same_as(op->false_value)) {
            return op;
-        } else {
-            int lanes = std::max(true_value.type().lanes(), false_value.type().lanes());
-            lanes = std::max(lanes, condition.type().lanes());
-            // Widen the true and false values, but we don't have to widen the condition
-            true_value = widen(true_value, lanes);
-            false_value = widen(false_value, lanes);
-            condition = widen(condition, lanes);
-            return Select::make(condition, true_value, false_value);
        }
+
+        int lanes = std::max({condition.type().lanes(),
+                              true_value.type().lanes(),
+                              false_value.type().lanes()});
+        return Select::make(widen(condition, lanes),
+                            widen(true_value, lanes),
+                            widen(false_value, lanes));
    }
 
    Expr visit(const Load *op) override {
@@ -633,12 +524,15 @@
 
        if (predicate.same_as(op->predicate) && index.same_as(op->index)) {
            return op;
-        } else {
-            int w = index.type().lanes();
-            predicate = widen(predicate, w);
-            return Load::make(op->type.with_lanes(w), op->name, index, op->image,
-                              op->param, predicate, op->alignment);
        }
+
+        int lanes = std::max(predicate.type().lanes(), index.type().lanes());
+        return Load::make(op->type.with_lanes(lanes),
+                          op->name,
+                          widen(index, lanes),
+                          op->image, op->param,
+                          widen(predicate, lanes),
+                          op->alignment);
    }
 
    Expr visit(const Call *op) override {
@@ -715,12 +609,12 @@
            }
            return Call::make(op->type, Call::trace, new_args, op->call_type);
        } else if (op->is_intrinsic(Call::if_then_else) && op->args.size() == 2) {
-            Expr cond = widen(new_args[0], max_lanes);
            Expr true_value = widen(new_args[1], max_lanes);
-
-            const Load *load = true_value.as<Load>();
-            if (load) {
-                return Load::make(op->type.with_lanes(max_lanes), load->name, load->index, load->image, load->param, cond, load->alignment);
+            if (const Load *load = true_value.as<Load>()) {
+                return Load::make(op->type.with_lanes(max_lanes),
+                                  load->name, load->index, load->image, load->param,
+                                  widen(new_args[0], max_lanes),
+                                  load->alignment);
            }
        }
 
@@ -731,9 +625,9 @@
 
        Type new_op_type = op->type.with_lanes(max_lanes);
        if (op->is_intrinsic(Call::prefetch)) {
-            // We don't want prefetch args to ve vectorized, but we can't just skip the mutation
-            // (otherwise we can end up with dead loop variables. Instead, use extract_lane() on each arg
-            // to scalarize it again.
+            // We don't want prefetch args to be vectorized, but we can't just skip the mutation
+            // (otherwise we can end up with dead loop variables). Instead, use extract_lane() on
+            // each arg to scalarize it again.
            for (auto &arg : new_args) {
                if (arg.type().is_vector()) {
                    arg = extract_lane(arg, 0);
@@ -828,13 +722,13 @@
 
        if (predicate.same_as(op->predicate) &&
            value.same_as(op->value) &&
            index.same_as(op->index)) {
            return op;
-        } else {
-            int lanes = std::max({predicate.type().lanes(),
-                                  value.type().lanes(),
-                                  index.type().lanes()});
-            return Store::make(op->name, widen(value, lanes), widen(index, lanes),
-                               op->param, widen(predicate, lanes), op->alignment);
        }
+
+        int lanes = std::max({predicate.type().lanes(),
+                              value.type().lanes(),
+                              index.type().lanes()});
+        return Store::make(op->name, widen(value, lanes), widen(index, lanes),
+                           op->param, widen(predicate, lanes), op->alignment);
    }
 
    Stmt visit(const AssertStmt *op) override {
@@ -857,25 +751,6 @@
        // which would mean control flow divergence within the
        // SIMD lanes.
 
-        bool vectorize_predicate = true;
-
-        Stmt predicated_stmt;
-        if (vectorize_predicate) {
-            PredicateLoadStore p(vectorized_vars.front().name, cond);
-            predicated_stmt = p.mutate(then_case);
-            vectorize_predicate = p.is_vectorized();
-        }
-        if (vectorize_predicate && else_case.defined()) {
-            PredicateLoadStore p(vectorized_vars.front().name, !cond);
-            predicated_stmt = Block::make(predicated_stmt, p.mutate(else_case));
-            vectorize_predicate = p.is_vectorized();
-        }
-
-        debug(4) << "IfThenElse should vectorize predicate "
-                 << "? " << vectorize_predicate << "; cond: " << cond << "\n";
-        debug(4) << "Predicated stmt:\n"
-                 << predicated_stmt << "\n";
-
        // First check if the condition is marked as likely.
        if (const Call *likely = Call::as_intrinsic(cond, {Call::likely, Call::likely_if_innermost})) {
@@ -890,48 +765,31 @@
 
            all_true = Call::make(Bool(), likely->name, {all_true}, Call::PureIntrinsic);
 
-            if (!vectorize_predicate) {
-                // We should strip the likelies from the case
-                // that's going to scalarize, because it's no
-                // longer likely.
-                Stmt without_likelies =
-                    IfThenElse::make(unwrap_tags(op->condition),
-                                     op->then_case, op->else_case);
-
-                // scalarize() will put back all vectorized loops around the statement as serial,
-                // but it still may happen that there are vectorized loops inside of the statement
-                // itself which we may want to handle. All the context is invalid though, so
-                // we just start anew for this specific statement.
-                Stmt scalarized = scalarize(without_likelies, false);
-                scalarized = vectorize_statement(scalarized);
-                Stmt stmt =
-                    IfThenElse::make(all_true,
-                                     then_case,
-                                     scalarized);
-                debug(4) << "...With all_true likely: \n"
-                         << stmt << "\n";
-                return stmt;
-            } else {
-                Stmt stmt =
-                    IfThenElse::make(all_true,
-                                     then_case,
-                                     predicated_stmt);
-                debug(4) << "...Predicated IfThenElse: \n"
-                         << stmt << "\n";
-                return stmt;
-            }
+            // We should strip the likelies from the case
+            // that's going to scalarize, because it's no
+            // longer likely.
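+            // A sketch of the intended shape (lane width 4; the names below
+            // are illustrative, not actual IR): a guard such as
+            //   if (likely(x < limit)) f[x] = g[x];
+            // becomes
+            //   if (all_true(ramp(x_base, 1, 4) < broadcast(limit, 4))) {
+            //     ...vectorized then_case...
+            //   } else {
+            //     ...the statement rebuilt and scalarized, likelies stripped...
+            //   }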
+            Stmt without_likelies =
+                IfThenElse::make(unwrap_tags(op->condition),
+                                 op->then_case, op->else_case);
+
+            // scalarize() will put back all vectorized loops around the statement as serial,
+            // but it still may happen that there are vectorized loops inside of the statement
+            // itself which we may want to handle. All the context is invalid though, so
+            // we just start anew for this specific statement.
+            Stmt scalarized = scalarize(without_likelies, false);
+            scalarized = vectorize_statement(scalarized);
+            Stmt stmt =
+                IfThenElse::make(all_true,
+                                 then_case,
+                                 scalarized);
+            debug(4) << "...With all_true likely: \n"
+                     << stmt << "\n";
+            return stmt;
        } else {
            // It's some arbitrary vector condition.
-            if (!vectorize_predicate) {
-                debug(4) << "...Scalarizing vector predicate: \n"
-                         << Stmt(op) << "\n";
-                return scalarize(op);
-            } else {
-                Stmt stmt = predicated_stmt;
-                debug(4) << "...Predicated IfThenElse: \n"
-                         << stmt << "\n";
-                return stmt;
-            }
+            debug(4) << "...Scalarizing vector predicate: \n"
+                     << Stmt(op) << "\n";
+            return scalarize(op);
        }
    } else {
        // It's an if statement on a scalar, we're ok to vectorize the innards.
@@ -972,64 +830,64 @@ class VectorSubs : public IRMutator {
            body = IfThenElse::make(likely(var < op->min + op->extent), body);
        }
 
-        if (op->for_type == ForType::Vectorized) {
-            const IntImm *extent_int = extent.as<IntImm>();
-            internal_assert(extent_int)
-                << "Vectorized for loop extent should have been rewritten to a constant\n";
-            if (extent_int->value <= 1) {
-                user_error << "Loop over " << op->name
-                           << " has extent " << extent
-                           << ". Can only vectorize loops over a "
-                           << "constant extent > 1\n";
-            }
+        if (op->for_type != ForType::Vectorized) {
+            body = mutate(body);
 
-            vectorized_vars.push_back({op->name, min, (int)extent_int->value});
-            update_replacements();
-            // Go over lets which were vectorized in the order of their occurrence and update
-            // them according to the current loop level.
-            for (const auto &[var, val] : containing_lets) {
-                // Skip if this var wasn't vectorized.
-                if (!scope.contains(var)) {
-                    continue;
-                }
-                string vectorized_name = get_widened_var_name(var);
-                Expr vectorized_value = mutate(scope.get(var));
-                vector_scope.push(vectorized_name, vectorized_value);
+            if (min.same_as(op->min) &&
+                extent.same_as(op->extent) &&
+                for_type == op->for_type &&
+                body.same_as(op->body)) {
+                return op;
            }
 
-            body = mutate(body);
+            return For::make(op->name, min, extent, for_type, op->partition_policy, op->device_api, body);
+        }
 
-            // Append vectorized lets for this loop level.
-            for (const auto &[var, _] : reverse_view(containing_lets)) {
-                // Skip if this var wasn't vectorized.
-                if (!scope.contains(var)) {
-                    continue;
-                }
-                string vectorized_name = get_widened_var_name(var);
-                Expr vectorized_value = vector_scope.get(vectorized_name);
-                vector_scope.pop(vectorized_name);
-                InterleavedRamp ir;
-                if (is_interleaved_ramp(vectorized_value, vector_scope, &ir)) {
-                    body = substitute(vectorized_name, vectorized_value, body);
-                } else {
-                    body = LetStmt::make(vectorized_name, vectorized_value, body);
-                }
+        const IntImm *extent_int = extent.as<IntImm>();
+        internal_assert(extent_int)
+            << "Vectorized for loop extent should have been rewritten to a constant\n";
+        if (extent_int->value <= 1) {
+            user_error << "Loop over " << op->name
+                       << " has extent " << extent
+                       << ". Can only vectorize loops over a "
+                       << "constant extent > 1\n";
+        }
+
+        vectorized_vars.push_back({op->name, min, (int)extent_int->value});
+        update_replacements();
+        // Go over lets which were vectorized in the order of their occurrence and update
+        // them according to the current loop level.
+        for (const auto &[var, val] : containing_lets) {
+            // Skip if this var wasn't vectorized.
+            if (!scope.contains(var)) {
+                continue;
            }
-            vectorized_vars.pop_back();
-            update_replacements();
-            return body;
-        } else {
-            body = mutate(body);
+            string vectorized_name = get_widened_var_name(var);
+            Expr vectorized_value = mutate(scope.get(var));
+            vector_scope.push(vectorized_name, vectorized_value);
+        }
 
-            if (min.same_as(op->min) &&
-                extent.same_as(op->extent) &&
-                body.same_as(op->body) &&
-                for_type == op->for_type) {
-                return op;
+        body = mutate(body);
+
+        // Append vectorized lets for this loop level.
+        for (const auto &[var, _] : reverse_view(containing_lets)) {
+            // Skip if this var wasn't vectorized.
+            if (!scope.contains(var)) {
+                continue;
+            }
+            string vectorized_name = get_widened_var_name(var);
+            Expr vectorized_value = vector_scope.get(vectorized_name);
+            vector_scope.pop(vectorized_name);
+            InterleavedRamp ir;
+            if (is_interleaved_ramp(vectorized_value, vector_scope, &ir)) {
+                body = substitute(vectorized_name, vectorized_value, body);
            } else {
-                return For::make(op->name, min, extent, for_type, op->partition_policy, op->device_api, body);
+                body = LetStmt::make(vectorized_name, vectorized_value, body);
            }
        }
+        vectorized_vars.pop_back();
+        update_replacements();
+        return body;
    }
 
    Stmt visit(const Allocate *op) override {
@@ -1086,218 +944,257 @@
        return Allocate::make(op->name, op->type, op->memory_type, new_extents,
                              op->condition, body, new_expr, op->free_function);
    }
 
-    Stmt visit(const Atomic *op) override {
-        // Recognize a few special cases that we can handle as within-vector reduction trees.
-        do {
-            if (!op->mutex_name.empty()) {
-                // We can't vectorize over a mutex
-                break;
-            }
+    // Recognize a few special cases that we can handle as within-vector reduction trees.
+    std::optional<Stmt> visit_atomic(const Atomic *op) {
+        if (!op->mutex_name.empty()) {
+            // We can't vectorize over a mutex
+            return std::nullopt;
+        }
 
-            const Store *store = op->body.as<Store>();
-            if (!store) {
-                break;
-            }
+        const Store *store = op->body.as<Store>();
+        if (!store) {
+            return std::nullopt;
+        }
 
-            // f[x] = y
-            if (!expr_uses_var(store->value, store->name) &&
-                !expr_uses_var(store->predicate, store->name)) {
-                // This can be naively vectorized just fine. If there are
-                // repeated values in the vectorized store index, the ordering
-                // of writes may be undetermined and backend-dependent, but
-                // they'll be atomic.
-                Stmt s = mutate(store);
-
-                // We may still need the atomic node, if there was more
-                // parallelism than just the vectorization.
-                s = Atomic::make(op->producer_name, op->mutex_name, s);
-                return s;
-            }
+        // f[x] = y
+        if (!expr_uses_var(store->value, store->name) &&
+            !expr_uses_var(store->predicate, store->name)) {
+            // This can be naively vectorized just fine. If there are
+            // repeated values in the vectorized store index, the ordering
+            // of writes may be undetermined and backend-dependent, but
+            // they'll be atomic.
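+            // For example (a sketch; 'hist' and 'idx' are illustrative names):
+            //   hist(idx(r)) = 1;
+            // The stored value never reads from 'hist', so the store can be
+            // widened directly even if idx(r) repeats a location.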
+            Stmt s = mutate(store);
 
-            // f[x] = f[x] <op> y
-            VectorReduce::Operator reduce_op = VectorReduce::Add;
-            Expr a, b;
-            if (const Add *add = store->value.as<Add>()) {
-                a = add->a;
-                b = add->b;
-                reduce_op = VectorReduce::Add;
-            } else if (const Mul *mul = store->value.as<Mul>()) {
-                a = mul->a;
-                b = mul->b;
-                reduce_op = VectorReduce::Mul;
-            } else if (const Min *min = store->value.as<Min>()) {
-                a = min->a;
-                b = min->b;
-                reduce_op = VectorReduce::Min;
-            } else if (const Max *max = store->value.as<Max>()) {
-                a = max->a;
-                b = max->b;
-                reduce_op = VectorReduce::Max;
-            } else if (const Cast *cast_op = store->value.as<Cast>()) {
-                if (cast_op->type.element_of() == UInt(8) &&
-                    cast_op->value.type().is_bool()) {
-                    if (const And *and_op = cast_op->value.as<And>()) {
-                        a = and_op->a;
-                        b = and_op->b;
-                        reduce_op = VectorReduce::And;
-                    } else if (const Or *or_op = cast_op->value.as<Or>()) {
-                        a = or_op->a;
-                        b = or_op->b;
-                        reduce_op = VectorReduce::Or;
-                    }
-                }
-            } else if (const Call *call_op = store->value.as<Call>()) {
-                if (call_op->is_intrinsic(Call::saturating_add)) {
-                    a = call_op->args[0];
-                    b = call_op->args[1];
-                    reduce_op = VectorReduce::SaturatingAdd;
+
+            // We may still need the atomic node, if there was more
+            // parallelism than just the vectorization.
+            s = Atomic::make(op->producer_name, op->mutex_name, s);
+            return s;
+        }
+
+        // f[x] = f[x] <op> y
+        VectorReduce::Operator reduce_op = VectorReduce::Add;
+        Expr a, b;
+        if (const Add *add = store->value.as<Add>()) {
+            a = add->a;
+            b = add->b;
+            reduce_op = VectorReduce::Add;
+        } else if (const Mul *mul = store->value.as<Mul>()) {
+            a = mul->a;
+            b = mul->b;
+            reduce_op = VectorReduce::Mul;
+        } else if (const Min *min = store->value.as<Min>()) {
+            a = min->a;
+            b = min->b;
+            reduce_op = VectorReduce::Min;
+        } else if (const Max *max = store->value.as<Max>()) {
+            a = max->a;
+            b = max->b;
+            reduce_op = VectorReduce::Max;
+        } else if (const Cast *cast_op = store->value.as<Cast>()) {
+            if (cast_op->type.element_of() == UInt(8) &&
+                cast_op->value.type().is_bool()) {
+                if (const And *and_op = cast_op->value.as<And>()) {
+                    a = and_op->a;
+                    b = and_op->b;
+                    reduce_op = VectorReduce::And;
+                } else if (const Or *or_op = cast_op->value.as<Or>()) {
+                    a = or_op->a;
+                    b = or_op->b;
+                    reduce_op = VectorReduce::Or;
                }
            }
-
-            if (!a.defined() || !b.defined()) {
-                break;
+        } else if (const Call *call_op = store->value.as<Call>()) {
+            if (call_op->is_intrinsic(Call::saturating_add)) {
+                a = call_op->args[0];
+                b = call_op->args[1];
+                reduce_op = VectorReduce::SaturatingAdd;
            }
+        }
 
-            // Bools get cast to uint8 for storage. Strip off that
-            // cast around any load.
-            if (b.type().is_bool()) {
-                const Cast *cast_op = b.as<Cast>();
-                if (cast_op) {
-                    b = cast_op->value;
-                }
+        if (!a.defined() || !b.defined()) {
+            return std::nullopt;
+        }
+
+        // Bools get cast to uint8 for storage. Strip off that
+        // cast around any load.
+        if (b.type().is_bool()) {
+            const Cast *cast_op = b.as<Cast>();
+            if (cast_op) {
+                b = cast_op->value;
            }
-            if (a.type().is_bool()) {
-                const Cast *cast_op = b.as<Cast>();
-                if (cast_op) {
-                    a = cast_op->value;
-                }
+        }
+        if (a.type().is_bool()) {
+            const Cast *cast_op = b.as<Cast>();
+            if (cast_op) {
+                a = cast_op->value;
            }
+        }
 
-            if (a.as<Variable>() && !b.as<Variable>()) {
-                std::swap(a, b);
-            }
+        if (a.as<Variable>() && !b.as<Variable>()) {
+            std::swap(a, b);
+        }
 
-            // We require b to be a var, because it should have been lifted.
-            const Variable *var_b = b.as<Variable>();
-            const Load *load_a = a.as<Load>();
-
-            if (!var_b ||
-                !scope.contains(var_b->name) ||
-                !load_a ||
-                load_a->name != store->name ||
-                !is_const_one(load_a->predicate) ||
-                !is_const_one(store->predicate)) {
-                break;
-            }
+        // We require b to be a var, because it should have been lifted.
+        const Load *load_a = a.as<Load>();
+        if (!load_a || load_a->name != store->name) {
+            debug(4) << "Unable to vectorize atomic store\n";
+            return std::nullopt;
+        }
 
-            b = vector_scope.get(get_widened_var_name(var_b->name));
-            Expr store_index = mutate(store->index);
-            Expr load_index = mutate(load_a->index);
-
-            // The load and store indices must be the same interleaved
-            // ramp (or the same scalar, in the total reduction case).
-            InterleavedRamp store_ir, load_ir;
-            Expr test;
-            if (store_index.type().is_scalar()) {
-                test = simplify(load_index == store_index);
-            } else if (is_interleaved_ramp(store_index, vector_scope, &store_ir) &&
-                       is_interleaved_ramp(load_index, vector_scope, &load_ir) &&
-                       store_ir.inner_repetitions == load_ir.inner_repetitions &&
-                       store_ir.outer_repetitions == load_ir.outer_repetitions &&
-                       store_ir.lanes == load_ir.lanes) {
-                test = simplify(store_ir.base == load_ir.base &&
-                                store_ir.stride == load_ir.stride);
-            }
+        const Variable *var_b = b.as<Variable>();
+        if (!var_b || !scope.contains(var_b->name)) {
+            debug(4) << "Unable to vectorize atomic store because RHS was not vectorized\n";
+            return std::nullopt;
+        }
 
-            if (!test.defined()) {
-                break;
-            }
+        if (!equal(load_a->predicate, store->predicate)) {
+            debug(4) << "Not vectorizing an atomic store with a mismatched predicate\n";
+            return std::nullopt;
+        }
+
+        b = vector_scope.get(get_widened_var_name(var_b->name));
+        Expr store_index = mutate(store->index);
+        Expr load_index = mutate(load_a->index);
+        Expr store_predicate = mutate(store->predicate);
+
+        // The load and store indices must be the same interleaved
+        // ramp (or the same scalar, in the total reduction case).
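+        // For example (a sketch with 8 lanes): an index that decomposes as
+        //   base = b, stride = 1, inner_repetitions = 2, i.e.
+        //   [b, b, b+1, b+1, b+2, b+2, b+3, b+3]
+        // must decompose identically on the load and store sides for the
+        // reduction tree built below to be sound.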
+ InterleavedRamp store_ir, load_ir; + Expr test; + if (store_index.type().is_scalar()) { + test = simplify(load_index == store_index); + } else if (is_interleaved_ramp(store_index, vector_scope, &store_ir) && + is_interleaved_ramp(load_index, vector_scope, &load_ir) && + store_ir.inner_repetitions == load_ir.inner_repetitions && + store_ir.outer_repetitions == load_ir.outer_repetitions && + store_ir.lanes == load_ir.lanes) { + test = simplify(store_ir.base == load_ir.base && + store_ir.stride == load_ir.stride); + } - if (is_const_zero(test)) { - break; - } else if (!is_const_one(test)) { - // TODO: try harder by substituting in more things in scope - break; + if (!test.defined()) { + return std::nullopt; + } + + if (is_const_zero(test)) { + return std::nullopt; + } else if (!is_const_one(test)) { + // TODO: try harder by substituting in more things in scope + return std::nullopt; + } + + auto binop = [=](const Expr &a, const Expr &b) { + switch (reduce_op) { + case VectorReduce::Add: + return a + b; + case VectorReduce::Mul: + return a * b; + case VectorReduce::Min: + return min(a, b); + case VectorReduce::Max: + return max(a, b); + case VectorReduce::And: + return a && b; + case VectorReduce::Or: + return a || b; + case VectorReduce::SaturatingAdd: + return saturating_add(a, b); } + internal_error << "Missing case for " << reduce_op; + }; - auto binop = [=](const Expr &a, const Expr &b) { + if (!is_const_one(store_predicate)) { + Expr identity = [&] { switch (reduce_op) { case VectorReduce::Add: - return a + b; + case VectorReduce::SaturatingAdd: + return make_zero(b.type()); case VectorReduce::Mul: - return a * b; + return make_one(b.type()); case VectorReduce::Min: - return min(a, b); + return b.type().max(); case VectorReduce::Max: - return max(a, b); + return b.type().min(); case VectorReduce::And: - return a && b; + return const_true(b.type().lanes()); case VectorReduce::Or: - return a || b; - case VectorReduce::SaturatingAdd: - return saturating_add(a, b); + return const_false(b.type().lanes()); } - return Expr(); - }; + internal_error << "Missing case for " << reduce_op; + }(); + b = Select::make(store_predicate, b, identity); + } - int output_lanes = 1; - if (store_index.type().is_scalar()) { - // The index doesn't depend on the value being - // vectorized, so it's a total reduction. + int output_lanes = 1; + if (store_index.type().is_scalar()) { + // The index doesn't depend on the value being + // vectorized, so it's a total reduction. - b = VectorReduce::make(reduce_op, b, 1); - } else { + b = VectorReduce::make(reduce_op, b, 1); + } else { - output_lanes = store_index.type().lanes() / (store_ir.inner_repetitions * store_ir.outer_repetitions); + output_lanes = store_index.type().lanes() / (store_ir.inner_repetitions * store_ir.outer_repetitions); - store_index = Ramp::make(store_ir.base, store_ir.stride, output_lanes / store_ir.base.type().lanes()); - if (store_ir.inner_repetitions > 1) { - b = VectorReduce::make(reduce_op, b, output_lanes * store_ir.outer_repetitions); - } + store_index = Ramp::make(store_ir.base, store_ir.stride, output_lanes / store_ir.base.type().lanes()); + if (store_ir.inner_repetitions > 1) { + b = VectorReduce::make(reduce_op, b, output_lanes * store_ir.outer_repetitions); + } - // Handle outer repetitions by unrolling the reduction - // over slices. - if (store_ir.outer_repetitions > 1) { - // First remove all powers of two with a binary reduction tree. 
- int reps = store_ir.outer_repetitions; - while (reps % 2 == 0) { - int l = b.type().lanes() / 2; - Expr b0 = Shuffle::make_slice(b, 0, 1, l); - Expr b1 = Shuffle::make_slice(b, l, 1, l); - b = binop(b0, b1); - reps /= 2; - } + // Handle outer repetitions by unrolling the reduction + // over slices. + if (store_ir.outer_repetitions > 1) { + // First remove all powers of two with a binary reduction tree. + int reps = store_ir.outer_repetitions; + while (reps % 2 == 0) { + int l = b.type().lanes() / 2; + Expr b0 = Shuffle::make_slice(b, 0, 1, l); + Expr b1 = Shuffle::make_slice(b, l, 1, l); + b = binop(b0, b1); + reps /= 2; + } - // Then reduce linearly over slices for the rest. - if (reps > 1) { - Expr v = Shuffle::make_slice(b, 0, 1, output_lanes); - for (int i = 1; i < reps; i++) { - Expr slice = simplify(Shuffle::make_slice(b, i * output_lanes, 1, output_lanes)); - v = binop(v, slice); - } - b = v; + // Then reduce linearly over slices for the rest. + if (reps > 1) { + Expr v = Shuffle::make_slice(b, 0, 1, output_lanes); + for (int i = 1; i < reps; i++) { + Expr slice = simplify(Shuffle::make_slice(b, i * output_lanes, 1, output_lanes)); + v = binop(v, slice); } + b = v; } } + } - Expr new_load = Load::make(load_a->type.with_lanes(output_lanes), - load_a->name, store_index, load_a->image, - load_a->param, const_true(output_lanes), - ModulusRemainder{}); + Expr new_load = Load::make(load_a->type.with_lanes(output_lanes), + load_a->name, store_index, load_a->image, + load_a->param, const_true(output_lanes), + ModulusRemainder{}); - Expr lhs = cast(b.type(), new_load); - b = binop(lhs, b); - b = cast(new_load.type(), b); + Expr lhs = cast(b.type(), new_load); + b = binop(lhs, b); + b = cast(new_load.type(), b); - Stmt s = Store::make(store->name, b, store_index, store->param, - const_true(b.type().lanes()), store->alignment); + Expr final_predicate = [&] { + if (is_const_one(store_predicate)) { + return const_true(output_lanes); + } + return VectorReduce::make(VectorReduce::Or, store_predicate, output_lanes); + }(); - // We may still need the atomic node, if there was more - // parallelism than just the vectorization. - s = Atomic::make(op->producer_name, op->mutex_name, s); + Stmt s = Store::make(store->name, b, store_index, + store->param, final_predicate, store->alignment); - return s; - } while (false); + // We may still need the atomic node, if there was more + // parallelism than just the vectorization. + s = Atomic::make(op->producer_name, op->mutex_name, s); + + return s; + } + + Stmt visit(const Atomic *op) override { + if (auto vectorized_op = visit_atomic(op)) { + return *vectorized_op; + } // In the general case, if a whole stmt has to be done // atomically, we need to serialize. @@ -1324,40 +1221,6 @@ class VectorSubs : public IRMutator { return s; } - Expr scalarize(Expr e) { - // This method returns a select tree that produces a vector lanes - // result expression - user_assert(replacements.size() == 1) << "Can't scalarize nested vectorization\n"; - string var = replacements.begin()->first; - Expr replacement = replacements.begin()->second; - - Expr result; - int lanes = replacement.type().lanes(); - - for (int i = lanes - 1; i >= 0; --i) { - // Hide all the vector let values in scope with a scalar version - // in the appropriate lane. 
-            for (Scope<Expr>::const_iterator iter = scope.cbegin(); iter != scope.cend(); ++iter) {
-                e = substitute(iter.name(),
-                               get_lane(Variable::make(iter.value().type(), iter.name()), i),
-                               e);
-            }
-
-            // Replace uses of the vectorized variable with the extracted
-            // lane expression
-            e = substitute(var, i, e);
-
-            if (i == lanes - 1) {
-                result = Broadcast::make(e, lanes);
-            } else {
-                Expr cond = (replacement == Broadcast::make(i, lanes));
-                result = Select::make(cond, Broadcast::make(e, lanes), result);
-            }
-        }
-
-        return result;
-    }
-
    // Recompute all replacements for vectorized vars based on
    // the current stack of vectorized loops.
    void update_replacements() {
@@ -1614,6 +1477,7 @@ class AllStoresInScope : public IRVisitor {
        : s(s) {
    }
 };
+
 bool all_stores_in_scope(const Stmt &stmt, const Scope<> &scope) {
    AllStoresInScope checker(scope);
    stmt.accept(&checker);
@@ -1660,12 +1524,95 @@ Stmt vectorize_statement(const Stmt &stmt) {
    return VectorizeLoops().mutate(stmt);
 }
 
+struct PredicateOps : IRMutator {
+    using IRMutator::visit;
+
+    Scope<> vectorized_vars;
+    std::vector<Expr> predicates;
+
+    Stmt visit(const For *op) override {
+        if (op->for_type != ForType::Vectorized) {
+            return IRMutator::visit(op);
+        }
+
+        vectorized_vars.push(op->name);
+        Stmt body = mutate(op->body);
+        vectorized_vars.pop(op->name);
+
+        if (body.same_as(op->body)) {
+            return op;
+        }
+
+        return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api,
+                         body);
+    }
+
+    Stmt visit(const IfThenElse *op) override {
+        if (!expr_uses_vars(op->condition, vectorized_vars)) {
+            // No need to mutate condition if it doesn't use any vectorized vars
+            Stmt then_case = mutate(op->then_case);
+            Stmt else_case = mutate(op->else_case);
+            if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
+                return op;
+            }
+            return IfThenElse::make(op->condition, then_case, else_case);
+        }
+
+        predicates.push_back(op->condition);
+        Stmt then_case = mutate(op->then_case);
+        predicates.pop_back();
+
+        Stmt else_case = op->else_case;
+        if (!else_case.defined()) {
+            return then_case;
+        }
+
+        predicates.push_back(!op->condition);
+        else_case = mutate(else_case);
+        predicates.pop_back();
+
+        return Block::make(then_case, else_case);
+    }
+
+    Expr apply_predicates(Expr pred) const {
+        for (const Expr &p : predicates) {
+            pred = pred && p;
+        }
+        return simplify(pred);
+    }
+
+    Expr visit(const Load *op) override {
+        if (predicates.empty()) {
+            return op;
+        }
+        return Load::make(op->type, op->name,
+                          mutate(op->index),
+                          op->image, op->param,
+                          apply_predicates(op->predicate),
+                          op->alignment);
+    }
+
+    Stmt visit(const Store *op) override {
+        if (predicates.empty()) {
+            return op;
+        }
+        return Store::make(op->name,
+                           mutate(op->value), mutate(op->index),
+                           op->param,
+                           apply_predicates(op->predicate),
+                           op->alignment);
+    }
+};
+
 }  // namespace
+
 Stmt vectorize_loops(const Stmt &stmt, const map<string, Function> &env) {
    // Limit the scope of atomic nodes to just the necessary stuff.
    // TODO: Should this be an earlier pass? It's probably a good idea
    // for non-vectorizing stuff too.
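+    // PredicateOps folds if-conditions that depend on a vectorized loop
+    // variable into the predicates of the enclosed loads/stores, e.g.
+    // (a sketch):
+    //   if (x < t) { out[x] = in[x]; }
+    // inside a loop vectorized over x becomes a load/store pair predicated
+    // on x < t, which the vectorizer can then widen instead of scalarizing.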
-    Stmt s = LiftVectorizableExprsOutOfAllAtomicNodes(env).mutate(stmt);
+    Stmt s = stmt;
+    s = LiftVectorizableExprsOutOfAllAtomicNodes(env).mutate(s);
+    s = PredicateOps().mutate(s);
    s = vectorize_statement(s);
    s = RemoveUnnecessaryAtomics().mutate(s);
    return s;
diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt
index adac3bae09e7..2f964e70672c 100644
--- a/test/correctness/CMakeLists.txt
+++ b/test/correctness/CMakeLists.txt
@@ -350,6 +350,7 @@ tests(GROUPS correctness
    vector_reductions.cpp
    vector_shuffle.cpp
    vector_tile.cpp
+    vectorize_atomic_predicate.cpp
    vectorize_guard_with_if.cpp
    vectorize_mixed_widths.cpp
    vectorize_nested.cpp
diff --git a/test/correctness/predicated_store_load.cpp b/test/correctness/predicated_store_load.cpp
index 3e21e0f9a3d7..0cf47c34c944 100644
--- a/test/correctness/predicated_store_load.cpp
+++ b/test/correctness/predicated_store_load.cpp
@@ -124,7 +124,7 @@ int predicated_tail_with_scalar_test(const Target &t) {
    if (t.has_feature(Target::HVX)) {
        f.hexagon();
    }
-    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(1, 0));
+    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(1, 1));
 
    Buffer<int> im = f.realize({size, size});
    auto func = [](int x, int y) {
@@ -158,7 +158,7 @@ int vectorized_predicated_store_scalarized_predicated_load_test(const Target &t) {
        f.update(0).hexagon();
    }
 
-    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(2, 6));
+    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(1, 3));
 
    Buffer<int> im = f.realize({170, 170});
    auto func = [im_ref](int x, int y, int z) { return im_ref(x, y, z); };
@@ -185,7 +185,7 @@ int vectorized_dense_load_with_stride_minus_one_test(const Target &t) {
    if (t.has_feature(Target::HVX)) {
        f.hexagon();
    }
-    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(3, 6));
+    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(1, 2));
 
    Buffer<int> im = f.realize({size, size});
    auto func = [&im_ref, &im](int x, int y, int z) {
@@ -253,7 +253,7 @@ int scalar_load_test(const Target &t) {
        f.update(0).hexagon();
    }
 
-    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(1, 2));
+    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(1, 3));
 
    Buffer<int> im = f.realize({160, 160});
    auto func = [im_ref](int x, int y, int z) { return im_ref(x, y, z); };
@@ -287,7 +287,7 @@ int scalar_store_test(const Target &t) {
        f.update(0).hexagon();
    }
 
-    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(1, 1));
+    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(1, 2));
 
    Buffer<int> im = f.realize({160, 160});
    auto func = [im_ref](int x, int y, int z) { return im_ref(x, y, z); };
@@ -320,7 +320,7 @@ int not_dependent_on_vectorized_var_test(const Target &t) {
    if (t.has_feature(Target::HVX)) {
        f.update(0).hexagon();
    }
-    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(0, 0));
+    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(1, 2));
 
    Buffer<int> im = f.realize({160, 160, 160});
    auto func = [im_ref](int x, int y, int z) { return im_ref(x, y, z); };
@@ -382,7 +382,7 @@ int vectorized_predicated_predicate_with_pure_call_test(const Target &t) {
    if (t.has_feature(Target::HVX)) {
        f.update(0).hexagon();
    }
-    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(2, 4));
+    f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(1, 2));
 
    Buffer<int> im = f.realize({160, 160});
    auto func = [im_ref](int x, int y, int z) { return im_ref(x, y, z); };
diff --git a/test/correctness/vectorize_atomic_predicate.cpp b/test/correctness/vectorize_atomic_predicate.cpp
new file mode 100644
index 000000000000..f2323a6c577d
--- /dev/null
+++ b/test/correctness/vectorize_atomic_predicate.cpp
@@ -0,0 +1,59 @@
+#include "Halide.h"
+
+using namespace Halide;
+
+int main(int argc, char **argv) {
+    ImageParam mat(Float(32), 2, "mat");
+    mat.dim(0).set_min(0).set_extent(mat.dim(0).extent() / 4 * 4);
+    mat.dim(1).set_min(0).set_stride(mat.dim(0).extent());
+
+    ImageParam vec(Float(32), 1, "vec");
+    vec.dim(0).set_bounds(0, mat.dim(0).extent());
+
+    Func mv{"mv"};
+    Var x{"x"};
+
+    // RDom r(0, vec.dim(0).extent() / 4 * 4);  // <- works with this, because no tail
+    RDom r(0, vec.dim(0).extent());
+    mv(x) += mat(r, x) * vec(r);
+
+    Func out = mv.in();
+
+    RVar ro{"ro"}, ri{"ri"};
+    Var u{"u"};
+
+    out.output_buffer().dim(0).set_bounds(0, mat.dim(1).extent() / 4 * 4);
+    out.vectorize(x, 4);
+
+    auto intm = mv.update().split(r, ro, ri, 4, TailStrategy::Predicate).rfactor(ri, u);
+    intm.compute_at(out, x)
+        .reorder_storage(u, x)
+        .vectorize(u)
+        .unroll(x);
+
+    intm.update().reorder(x, u, ro).vectorize(u).unroll(x);
+
+    mv.update().atomic().vectorize(ri, 4);
+    mv.bound_extent(x, 4);
+
+    struct TestMatcher : Internal::IRVisitor {
+        using IRVisitor::visit;
+        void visit(const Internal::VectorReduce *op) override {
+            found = true;
+        }
+        bool found = false;
+    } matcher;
+
+    auto stmt = out.compile_to_module(out.infer_arguments())
+                    .get_conceptual_stmt();
+    stmt.accept(&matcher);
+
+    if (!matcher.found) {
+        std::cout << "Did not find a VectorReduce node.\n";
+        std::cout << stmt << "\n";
+        return 1;
+    }
+
+    std::cout << "Success!\n";
+    return 0;
+}
diff --git a/test/correctness/vectorize_guard_with_if.cpp b/test/correctness/vectorize_guard_with_if.cpp
index dad6d0fa208c..0f140a039885 100644
--- a/test/correctness/vectorize_guard_with_if.cpp
+++ b/test/correctness/vectorize_guard_with_if.cpp
@@ -24,8 +24,8 @@ int main(int argc, char **argv) {
    const int w = 100, v = 8;
    f.vectorize(x, v, tail_strategy);
 
-    const int expected_vector_stores = w / v;
-    const int expected_scalar_stores = w % v;
+    const int expected_vector_stores = w / v + (w % v != 0 ? 1 : 0);
+    const int expected_scalar_stores = 0;
 
    f.jit_handlers().custom_trace = &my_trace;
    f.trace_stores();
diff --git a/test/correctness/vectorize_nested.cpp b/test/correctness/vectorize_nested.cpp
index e9800ef8f034..a10b283573db 100644
--- a/test/correctness/vectorize_nested.cpp
+++ b/test/correctness/vectorize_nested.cpp
@@ -230,45 +230,37 @@ int vectorize_inner_of_scalarization() {
        .vectorize(y);
 
    // We are looking for a specific loop, which shouldn't have been scalarized.
- class CheckForScalarizedLoop : public Internal::IRMutator { + struct CheckForScalarizedLoop : Internal::IRMutator { using IRMutator::visit; Internal::Stmt visit(const Internal::For *op) override { if (Internal::ends_with(op->name, ".x_inner")) { - *x_loop_found = true; + x_loop_found = true; } if (Internal::ends_with(op->name, ".y_inner")) { - *y_loop_found = true; + y_loop_found = true; } return IRMutator::visit(op); } - public: - explicit CheckForScalarizedLoop(bool *fx, bool *fy) - : x_loop_found(fx), y_loop_found(fy) { - } - - bool *x_loop_found = nullptr; - bool *y_loop_found = nullptr; - }; - - bool is_x_loop_found = false; - bool is_y_loop_found = false; + bool x_loop_found = false; + bool y_loop_found = false; + } checker; - out.add_custom_lowering_pass(new CheckForScalarizedLoop(&is_x_loop_found, &is_y_loop_found)); + out.add_custom_lowering_pass(&checker, nullptr); out.compile_jit(); - if (is_x_loop_found) { + if (checker.x_loop_found) { std::cerr << "Found scalarized loop for " << x << "\n"; return 1; } - if (!is_y_loop_found) { - std::cerr << "Expected to find scalarized loop for " << y << "\n"; + if (checker.y_loop_found) { + std::cerr << "Found scalarized loop for " << y << "\n"; return 1; }
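
For reference, a minimal end-to-end sketch of the behavior the updated vectorize_guard_with_if.cpp expectations encode (illustrative code, assuming a build with this patch applied): with TailStrategy::GuardWithIf, a tail that is not a multiple of the vector width is now emitted as a single predicated vector store rather than a run of scalar stores.

#include "Halide.h"
using namespace Halide;

int main() {
    Func f("f");
    Var x("x");
    f(x) = 2 * x;

    // 100 is not a multiple of 8; the final partial vector is handled by a
    // predicated vector store instead of peeled scalar stores.
    f.vectorize(x, 8, TailStrategy::GuardWithIf);

    Buffer<int> out = f.realize({100});
    for (int i = 0; i < 100; i++) {
        if (out(i) != 2 * i) {
            return 1;
        }
    }
    return 0;
}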