@@ -1521,44 +1521,6 @@ Stmt zero_gpu_loop_mins(const Stmt &s) {
15211521
15221522namespace {
15231523
1524- // Find the innermost GPU block of a statement; found_gpu_block stays nullptr if the traversal reaches no GPU block.
1525- class FindInnermostGPUBlock : public IRVisitor {
1526- using IRVisitor::visit;
1527-
1528- void visit (const For *op) override {
1529- if (op->for_type == ForType::GPUBlock) {
1530- // Overwrite on every GPU block reached, so whichever one is visited last is what remains.
1531- found_gpu_block = op;
1532- }
1533- IRVisitor::visit (op); // hand off to the base visitor (presumably continues into the loop body -- confirm)
1534- }
1535-
1536- public:
1537- const For *found_gpu_block = nullptr ; // result of the search; nullptr until a GPU block is seen
1538- };
1539-
1540- // Given a condition and a loop, guard the loop body with the condition: the matching For is
1541- // rebuilt with its body wrapped in an IfThenElse whose else case is an empty Stmt().
1542- class AddConditionToALoop : public IRMutator {
1543- using IRMutator::visit;
1544-
1545- Stmt visit (const For *op) override {
1546- if (op != loop) { // the target is matched by pointer identity, not by loop name
1547- return IRMutator::visit (op);
1548- }
1549-
1550- return For::make (op->name , op->min , op->extent , op->for_type , op->partition_policy , op->device_api , // every other field is copied unchanged
1551- IfThenElse::make (condition, op->body , Stmt ())); // the original body is wrapped as-is; it is not mutated further here
1552- }
1553-
1554- public:
1555- AddConditionToALoop (const Expr &condition, const For *loop)
1556- : condition(condition), loop(loop) {
1557- }
1558- const Expr &condition; // held by reference: the Expr passed in must outlive this mutator
1559- const For *loop; // the loop whose body receives the condition
1560- };
1561-
15621524// Push if statements between GPU blocks through all GPU blocks.
15631525// Throw error if the if statement has an else clause.
15641526class NormalizeIfStatements : public IRMutator {
@@ -1578,11 +1540,24 @@ class NormalizeIfStatements : public IRMutator {
15781540 if (!inside_gpu_blocks) {
15791541 return IRMutator::visit (op);
15801542 }
1581- FindInnermostGPUBlock find;
1582- op->accept (&find);
1583- if (find.found_gpu_block != nullptr ) {
1543+ const For *innermost_gpu_block = nullptr ;
1544+ visit_with (*op, [&](auto *self, const For *loop) {
1545+ if (loop->for_type == ForType::GPUBlock) {
1546+ innermost_gpu_block = loop;
1547+ }
1548+ self->visit_base (loop);
1549+ });
1550+ if (innermost_gpu_block != nullptr ) {
15841551 internal_assert (!op->else_case .defined ()) << " Found an if statement with else case between two GPU blocks.\n " ;
1585- return AddConditionToALoop (op->condition , find.found_gpu_block ).mutate (op->then_case );
1552+ // Add the condition to the loop body
1553+ return mutate_with (op->then_case , [&](auto *self, const For *loop) {
1554+ if (loop != innermost_gpu_block) {
1555+ return self->visit_base (loop);
1556+ }
1557+ return For::make (
1558+ loop->name , loop->min , loop->extent , loop->for_type , loop->partition_policy , loop->device_api ,
1559+ IfThenElse::make (op->condition , loop->body , Stmt ()));
1560+ });
15861561 }
15871562 return IRMutator::visit (op);
15881563 }
0 commit comments