Skip to content

Commit b8b5d14

Browse files
committed
FuseGPUThreadLoops.cpp: convert FindInnermostGPUBlock & AddConditionToALoop
1 parent 69a57ae commit b8b5d14

File tree

1 file changed

+17
-42
lines changed

1 file changed

+17
-42
lines changed

src/FuseGPUThreadLoops.cpp

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,44 +1521,6 @@ Stmt zero_gpu_loop_mins(const Stmt &s) {
15211521

15221522
namespace {
15231523

1524-
// Find the inner most GPU block of a statement.
1525-
class FindInnermostGPUBlock : public IRVisitor {
1526-
using IRVisitor::visit;
1527-
1528-
void visit(const For *op) override {
1529-
if (op->for_type == ForType::GPUBlock) {
1530-
// Set the last found GPU block to found_gpu_block.
1531-
found_gpu_block = op;
1532-
}
1533-
IRVisitor::visit(op);
1534-
}
1535-
1536-
public:
1537-
const For *found_gpu_block = nullptr;
1538-
};
1539-
1540-
// Given a condition and a loop, add the condition
1541-
// to the loop body.
1542-
class AddConditionToALoop : public IRMutator {
1543-
using IRMutator::visit;
1544-
1545-
Stmt visit(const For *op) override {
1546-
if (op != loop) {
1547-
return IRMutator::visit(op);
1548-
}
1549-
1550-
return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api,
1551-
IfThenElse::make(condition, op->body, Stmt()));
1552-
}
1553-
1554-
public:
1555-
AddConditionToALoop(const Expr &condition, const For *loop)
1556-
: condition(condition), loop(loop) {
1557-
}
1558-
const Expr &condition;
1559-
const For *loop;
1560-
};
1561-
15621524
// Push if statements between GPU blocks through all GPU blocks.
15631525
// Throw error if the if statement has an else clause.
15641526
class NormalizeIfStatements : public IRMutator {
@@ -1578,11 +1540,24 @@ class NormalizeIfStatements : public IRMutator {
15781540
if (!inside_gpu_blocks) {
15791541
return IRMutator::visit(op);
15801542
}
1581-
FindInnermostGPUBlock find;
1582-
op->accept(&find);
1583-
if (find.found_gpu_block != nullptr) {
1543+
const For *innermost_gpu_block = nullptr;
1544+
visit_with(*op, [&](auto *self, const For *loop) {
1545+
if (loop->for_type == ForType::GPUBlock) {
1546+
innermost_gpu_block = loop;
1547+
}
1548+
self->visit_base(loop);
1549+
});
1550+
if (innermost_gpu_block != nullptr) {
15841551
internal_assert(!op->else_case.defined()) << "Found an if statement with else case between two GPU blocks.\n";
1585-
return AddConditionToALoop(op->condition, find.found_gpu_block).mutate(op->then_case);
1552+
// Add the condition to the loop body
1553+
return mutate_with(op->then_case, [&](auto *self, const For *loop) {
1554+
if (loop != innermost_gpu_block) {
1555+
return self->visit_base(loop);
1556+
}
1557+
return For::make(
1558+
loop->name, loop->min, loop->extent, loop->for_type, loop->partition_policy, loop->device_api,
1559+
IfThenElse::make(op->condition, loop->body, Stmt()));
1560+
});
15861561
}
15871562
return IRMutator::visit(op);
15881563
}

0 commit comments

Comments
 (0)