Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Teach unrolling to exploit conditions in enclosing ifs #7969

Merged
merged 4 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ SOURCE_FILES = \
BoundaryConditions.cpp \
Bounds.cpp \
BoundsInference.cpp \
BoundConstantExtentLoops.cpp \
BoundSmallAllocations.cpp \
Buffer.cpp \
Callable.cpp \
Expand Down Expand Up @@ -665,6 +666,7 @@ HEADER_FILES = \
BoundaryConditions.h \
Bounds.h \
BoundsInference.h \
BoundConstantExtentLoops.h \
BoundSmallAllocations.h \
Buffer.h \
Callable.h \
Expand Down
136 changes: 136 additions & 0 deletions src/BoundConstantExtentLoops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#include "BoundConstantExtentLoops.h"
#include "Bounds.h"
#include "CSE.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Simplify.h"
#include "SimplifyCorrelatedDifferences.h"
#include "Substitute.h"

namespace Halide {
namespace Internal {

namespace {
// Mutator that gives every unrolled or vectorized loop a constant extent,
// using enclosing pure let bindings and enclosing if-conditions to help the
// simplifier prove one.
class BoundLoops : public IRMutator {
    using IRMutator::visit;

    // Pure let bindings enclosing the node being visited, outermost first.
    std::vector<std::pair<std::string, Expr>> lets;

    // Conditions from enclosing IfThenElse nodes that hold at the node
    // being visited.
    std::vector<Expr> facts;

    // True when HL_PERMIT_FAILED_UNROLL=1; see the constructor.
    bool permit_failed_unroll;

    // Mutate s while `fact` is recorded as known-true, then forget it again.
    Stmt mutate_assuming(const Expr &fact, const Stmt &s) {
        facts.push_back(fact);
        Stmt result = mutate(s);
        facts.pop_back();
        return result;
    }

    Stmt visit(const LetStmt *op) override {
        // Impure values are not tracked; only pure ones are recorded for
        // later substitution into loop extents.
        if (!is_pure(op->value)) {
            return IRMutator::visit(op);
        }

        lets.emplace_back(op->name, op->value);
        Stmt result = IRMutator::visit(op);
        lets.pop_back();
        return result;
    }

    Stmt visit(const IfThenElse *op) override {
        // The condition holds inside the then-branch ...
        Stmt new_then = mutate_assuming(op->condition, op->then_case);

        // ... and its negation holds inside the else-branch, if there is one.
        Stmt new_else;
        if (op->else_case.defined()) {
            new_else = mutate_assuming(simplify(!op->condition), op->else_case);
        }

        const bool unchanged =
            new_then.same_as(op->then_case) &&
            new_else.same_as(op->else_case);
        if (unchanged) {
            return op;
        }
        return IfThenElse::make(op->condition, new_then, new_else);
    }

    Stmt visit(const For *op) override {
        const bool wants_constant_extent =
            op->for_type == ForType::Unrolled ||
            op->for_type == ForType::Vectorized;

        // Loops that already have a constant extent, and loops that are
        // neither unrolled nor vectorized, only need their children visited.
        if (is_const(op->extent) || !wants_constant_extent) {
            return IRMutator::visit(op);
        }

        Stmt body = op->body;

        // First attempt: plain simplification of the extent.
        Expr extent = simplify(op->extent);
        const IntImm *e = extent.as<IntImm>();

        if (e == nullptr) {
            // Second attempt, before we hard fail: wrap the extent in every
            // enclosing pure let (innermost applied first so the outermost
            // ends up outermost), strip likely tags, substitute the lets in,
            // and simplify again under the facts from enclosing ifs.
            for (auto it = lets.rbegin(); it != lets.rend(); ++it) {
                extent = Let::make(it->first, it->second, extent);
            }
            extent = substitute_in_all_lets(remove_likelies(extent));
            extent = simplify(extent,
                              true,
                              Scope<Interval>::empty_scope(),
                              Scope<ModulusRemainder>::empty_scope(),
                              facts);
            e = extent.as<IntImm>();
        }

        // Declared at function scope because e may point into it below.
        Expr extent_upper;
        if (e == nullptr) {
            // Third attempt: settle for a constant upper bound, and guard
            // the body so iterations past the true extent do nothing.
            extent_upper = find_constant_bound(extent, Direction::Upper, Scope<Interval>());
            if (extent_upper.defined()) {
                e = extent_upper.as<IntImm>();
                Expr loop_var = Variable::make(Int(32), op->name);
                Expr in_range = loop_var < op->min + op->extent;
                body = IfThenElse::make(likely_if_innermost(in_range), body);
            }
        }

        const bool is_unroll = op->for_type == ForType::Unrolled;

        if (e == nullptr && permit_failed_unroll && is_unroll) {
            // No constant found, but we have permission to fail: demote the
            // unrolled loop to a serial loop with its original extent.
            user_warning << "HL_PERMIT_FAILED_UNROLL is allowing us to unroll a non-constant loop into a serial loop. Did you mean to do this?\n";
            Stmt serial_body = mutate(body);
            return For::make(op->name, op->min, op->extent,
                             ForType::Serial, op->partition_policy, op->device_api,
                             std::move(serial_body));
        }

        user_assert(e)
            << "Can only " << (is_unroll ? "unroll" : "vectorize")
            << " for loops over a constant extent.\n"
            << "Loop over " << op->name << " has extent " << extent << ".\n";

        body = mutate(body);
        return For::make(op->name, op->min, e,
                         op->for_type, op->partition_policy, op->device_api,
                         std::move(body));
    }

public:
    // Experimental autoschedulers may want to unroll without being
    // totally confident the loop will indeed turn out to be
    // constant-sized. If this feature continues to be important, we
    // need to expose it in the scheduling language somewhere, but how?
    // For now we do something ugly and expedient: read an environment
    // variable.
    //
    // For the tracking issue to fix this, see
    // https://github.com/halide/Halide/issues/3479
    BoundLoops()
        : permit_failed_unroll(get_env_variable("HL_PERMIT_FAILED_UNROLL") == "1") {
    }
};

} // namespace

// Entry point of the pass: run a fresh BoundLoops mutator over the
// statement and hand back the rewritten result.
Stmt bound_constant_extent_loops(const Stmt &s) {
    BoundLoops pass;
    return pass.mutate(s);
}

} // namespace Internal
} // namespace Halide
24 changes: 24 additions & 0 deletions src/BoundConstantExtentLoops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef HALIDE_BOUND_CONSTANT_EXTENT_LOOPS_H
#define HALIDE_BOUND_CONSTANT_EXTENT_LOOPS_H

/** \file
* Defines the lowering pass that enforces a constant extent on all
* vectorized or unrolled loops.
*/

#include "Expr.h"

namespace Halide {
namespace Internal {

/** Replace all loop extents of unrolled or vectorized loops with constants, by
 * substituting and simplifying as needed. Loops of other types, and loops whose
 * extent is already constant, keep their extent (their bodies are still
 * visited).
 *
 * The pass tries progressively harder to find a constant:
 *  - simplify the extent on its own;
 *  - substitute in the enclosing pure let bindings, strip likely tags, and
 *    simplify again while assuming the conditions of all enclosing if
 *    statements (or their negation, inside an else branch);
 *  - if we still can't determine a constant extent, but can determine a
 *    constant upper bound, use that bound as the extent and inject an if
 *    statement into the body that checks the loop variable against
 *    min + extent.
 *
 * If we can't even determine a constant upper bound, throw a user error. The
 * one exception is when the environment variable HL_PERMIT_FAILED_UNROLL is
 * set to "1": an unrolled (but not a vectorized) loop is then rewritten to a
 * serial loop with its original extent and a user warning is emitted. */
Stmt bound_constant_extent_loops(const Stmt &s);

} // namespace Internal
} // namespace Halide

#endif
4 changes: 2 additions & 2 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,11 +1013,11 @@ class BoundsInference : public IRMutator {
}

// Dump out the region required of each stage for debugging.

/*
debug(0) << "Box required of " << producer.name
<< " by " << consumer.name
<< " stage " << consumer.stage << ":\n";
<< " stage " << consumer.stage << ":\n"
<< " used: " << b.used << "\n";
for (size_t k = 0; k < b.size(); k++) {
debug(0) << " " << b[k].min << " ... " << b[k].max << "\n";
}
Expand Down
4 changes: 3 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ set(HEADER_FILES
BoundaryConditions.h
Bounds.h
BoundsInference.h
BoundSmallAllocations.h
BoundConstantExtentLoops.h
BoundSmallAllocations.h
Buffer.h
Callable.h
CanonicalizeGPUVars.h
Expand Down Expand Up @@ -189,6 +190,7 @@ set(SOURCE_FILES
BoundaryConditions.cpp
Bounds.cpp
BoundsInference.cpp
BoundConstantExtentLoops.cpp
BoundSmallAllocations.cpp
Buffer.cpp
Callable.cpp
Expand Down
5 changes: 5 additions & 0 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "AddParameterChecks.h"
#include "AllocationBoundsInference.h"
#include "AsyncProducers.h"
#include "BoundConstantExtentLoops.h"
#include "BoundSmallAllocations.h"
#include "Bounds.h"
#include "BoundsInference.h"
Expand Down Expand Up @@ -312,6 +313,10 @@ void lower_impl(const vector<Function> &output_funcs,
s = simplify_correlated_differences(s);
log("Lowering after simplifying correlated differences:", s);

debug(1) << "Bounding constant extent loops...\n";
s = bound_constant_extent_loops(s);
log("Lowering after bounding constant extent loops:", s);

debug(1) << "Unrolling...\n";
s = unroll_loops(s);
log("Lowering after unrolling:", s);
Expand Down
14 changes: 12 additions & 2 deletions src/Simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,13 @@ Simplify::ScopedFact::~ScopedFact() {

Expr simplify(const Expr &e, bool remove_dead_let_stmts,
const Scope<Interval> &bounds,
const Scope<ModulusRemainder> &alignment) {
const Scope<ModulusRemainder> &alignment,
const std::vector<Expr> &assumptions) {
Simplify m(remove_dead_let_stmts, &bounds, &alignment);
std::vector<Simplify::ScopedFact> facts;
for (const Expr &a : assumptions) {
facts.push_back(m.scoped_truth(a));
}
Expr result = m.mutate(e, nullptr);
if (m.in_unreachable) {
return unreachable(e.type());
Expand All @@ -366,8 +371,13 @@ Expr simplify(const Expr &e, bool remove_dead_let_stmts,

Stmt simplify(const Stmt &s, bool remove_dead_let_stmts,
const Scope<Interval> &bounds,
const Scope<ModulusRemainder> &alignment) {
const Scope<ModulusRemainder> &alignment,
const std::vector<Expr> &assumptions) {
Simplify m(remove_dead_let_stmts, &bounds, &alignment);
std::vector<Simplify::ScopedFact> facts;
for (const Expr &a : assumptions) {
facts.push_back(m.scoped_truth(a));
}
Stmt result = m.mutate(s);
if (m.in_unreachable) {
return Evaluate::make(unreachable());
Expand Down
17 changes: 10 additions & 7 deletions src/Simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@
namespace Halide {
namespace Internal {

/** Perform a a wide range of simplifications to expressions and
* statements, including constant folding, substituting in trivial
* values, arithmetic rearranging, etc. Simplifies across let
* statements, so must not be called on stmts with dangling or
* repeated variable names.
/** Perform a wide range of simplifications to expressions and statements,
* including constant folding, substituting in trivial values, arithmetic
* rearranging, etc. Simplifies across let statements, so must not be called on
* stmts with dangling or repeated variable names. Can optionally be passed
* known bounds of any variables, known alignment properties, and any other
* Exprs that should be assumed to be true.
*/
// @{
Stmt simplify(const Stmt &, bool remove_dead_code = true,
const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope());
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope(),
const std::vector<Expr> &assumptions = std::vector<Expr>());
Expr simplify(const Expr &, bool remove_dead_code = true,
const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope());
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope(),
const std::vector<Expr> &assumptions = std::vector<Expr>());
// @}

/** Attempt to statically prove an expression is true using the simplifier. */
Expand Down
Loading