Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【CINN】Add split with variable in factors and rewrite error handler of vectorize,unroll,bind schedule primitives #60449

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions paddle/cinn/ir/schedule/impl/for_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,37 @@ void DyScheduleImpl::Parallel(const Expr& loop) {
}

void DyScheduleImpl::Vectorize(const Expr& loop, int factor) {
CINN_IR_SCHEDULE_BEGIN();
std::string primitive = "Vectorize";
std::ostringstream os;
CHECK_GT(factor, 0) << "vectorize factor should be more than 0";
CHECK(loop.As<For>()->extent.is_constant())
<< "The loop to be vectorized should be constant!\n";
if (factor <= 0) {
os << "vectorize factor should be more than 0\n";
throw IRScheduleErrorHandler(primitive, os.str(), module_expr_);
}
if (!loop.As<For>()->extent.is_constant()) {
os << "The loop to be vectorized should be constant!\n";
throw IRScheduleErrorHandler(primitive, os.str(), module_expr_);
}
MutateForType(loop, ForType::Vectorized, factor);
CINN_IR_SCHEDULE_END(this->err_msg_level_);
}

void DyScheduleImpl::Unroll(const Expr& loop) {
CHECK(loop.As<For>()->extent.is_constant())
<< "The loop to be unrolled should be constant!\n";
CINN_IR_SCHEDULE_BEGIN();
std::string primitive = "Unroll";
std::ostringstream os;
if (!loop.As<For>()->extent.is_constant()) {
os << "The loop to be unrolled should be constant!\n";
throw IRScheduleErrorHandler(primitive, os.str(), module_expr_);
}
MutateForType(loop, ForType::Unrolled);
CINN_IR_SCHEDULE_END(this->err_msg_level_);
}

void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
#ifdef CINN_WITH_CUDA
CINN_IR_SCHEDULE_BEGIN();
std::string primitive = "Bind";
std::ostringstream os;

Expand Down Expand Up @@ -117,6 +134,7 @@ void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
}
MutateForType(loop, ForType::GPUThread, offset);
}
CINN_IR_SCHEDULE_END(this->err_msg_level_);
#endif
}
} // namespace ir
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/schedule/impl/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class DyScheduleImpl : public ScheduleBase {
std::vector<Expr> GetChildBlocks(const Expr& expr) const;
Expr GetBlock(const std::string& block_name) const;
std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
std::vector<Expr> Split(const Expr& loop, const std::vector<Expr>& factors);
std::vector<Expr> SamplePerfectTile(
utils::LinearRandomEngine::StateType* rand_seed,
const Expr& loop,
Expand Down Expand Up @@ -122,6 +123,7 @@ class StScheduleImpl : public ScheduleBase {
std::vector<Expr> GetChildBlocks(const Expr& expr) const;
Expr GetBlock(const std::string& block_name) const;
std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
std::vector<Expr> Split(const Expr& loop, const std::vector<Expr>& factors);
std::vector<Expr> SamplePerfectTile(
utils::LinearRandomEngine::StateType* rand_seed,
const Expr& loop,
Expand Down
67 changes: 66 additions & 1 deletion paddle/cinn/ir/schedule/impl/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/ir/schedule/impl/ir_schedule.h"

#include "paddle/cinn/common/integer_set.h"
#include "paddle/cinn/common/macros.h"

/** \brief A macro that guards the beginning of each implementation of schedule
*/
#define CINN_IR_SCHEDULE_BEGIN() try {
Expand Down Expand Up @@ -156,6 +158,63 @@ std::vector<Expr> DyScheduleImpl::Split(const Expr& loop,
return splited_loops;
}

// TODO(@LiuYang): now -1 can't exsit in factors,
std::vector<Expr> DyScheduleImpl::Split(const Expr& loop,
const std::vector<Expr>& factors) {
CHECK(loop.As<ir::For>())
<< "Expr param of Split must be For node! Please check.";
auto* for_node = loop.As<ir::For>();
CHECK(common::is_zero(for_node->min))
<< "The For node must start with 0! Please check.";
CHECK(!factors.empty())
<< "The factors param of Split should not be empty! Please check.";
CHECK(!loop.As<ir::For>()->extent.is_constant())
<< "Can't Split a loop with constant extent but with variable in "
"factors!";
Expr tot_extent = for_node->extent;

VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, "
<< tot_extent << ") to (" << cinn::utils::Join(factors, ", ")
<< ") at loop:\n"
<< loop;

std::vector<Expr> process_factors(factors);
Expr prod_size(1);
for (auto factor : factors) prod_size = prod_size * Expr(factor);
cinn::common::SymbolicExprAnalyzer analyzer({});
CHECK(analyzer.ProveEQ(tot_extent, prod_size).value_or(false))
<< "Product of factors can't be proved to be equal to the extent of "
"current for loop!";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前只支持可以证明相等的情况,未来如果有需求的话,可以考虑:无法证明相等时,增加predicate条件分支。


std::vector<Var> new_loop_vars;
Expr substitute_value(0);
for (int i = 0; i < process_factors.size(); ++i) {
Var temp_var(common::UniqName(for_node->loop_var->name));
substitute_value = Expr(temp_var) + substitute_value * process_factors[i];
new_loop_vars.push_back(temp_var);
}
substitute_value = cinn::common::AutoSimplify(substitute_value);
Expr new_node = ir::ir_utils::IRCopy(for_node->body);
ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value});
std::vector<Expr> splited_loops;
splited_loops.resize(process_factors.size());

for (int i = process_factors.size() - 1; i >= 0; i--) {
if (!new_node.As<ir::Block>()) new_node = Block::Make({new_node});
new_node = For::Make(new_loop_vars[i],
Expr(0),
process_factors[i],
for_node->for_type(),
for_node->device_api,
new_node);
splited_loops[i] = new_node;
}

this->Replace(loop, new_node);
VLOG(3) << "After Split, ir is:\n" << splited_loops.at(0);
return splited_loops;
}

Expr DyScheduleImpl::Fuse(const std::vector<Expr>& loops) {
VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n");
std::vector<const ir::For*> for_nodes;
Expand Down Expand Up @@ -369,6 +428,12 @@ std::vector<Expr> StScheduleImpl::Split(const Expr& loop,
return splited_loops;
}

std::vector<Expr> StScheduleImpl::Split(const Expr& loop,
const std::vector<Expr>& factors) {
CHECK(false) << "Static shape schedule don't support Split with some "
"variables in factors";
}

Expr StScheduleImpl::Fuse(const std::vector<Expr>& loops) {
VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n");
std::vector<const ir::For*> for_nodes;
Expand Down
15 changes: 10 additions & 5 deletions paddle/cinn/ir/schedule/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,16 @@ std::vector<Expr> IRSchedule::Split(const std::string& block_name,
std::vector<Expr> IRSchedule::Split(const Expr& loop,
const std::vector<Expr>& factors) {
std::vector<int> int_factors;
std::transform(factors.begin(),
factors.end(),
std::back_inserter(int_factors),
[](Expr x) { return x.as_int32(); });
auto results = impl_->Split(loop, int_factors);
std::vector<Expr> results;
std::for_each(factors.begin(), factors.end(), [&int_factors](const Expr& e) {
if (e.is_constant()) int_factors.push_back(e.as_int32());
});
if (int_factors.size() == factors.size()) {
results = impl_->Split(loop, int_factors);
} else {
results = impl_->Split(loop, factors);
}

trace_.Append(ScheduleDesc::Step(
"Split",
{{"loop", std::vector<Expr>({loop})}, {"factors", factors}},
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/schedule/schedule_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class ScheduleBase {
virtual Expr GetBlock(const std::string& block_name) const = 0;
virtual std::vector<Expr> Split(const Expr& loop,
const std::vector<int>& factors) = 0;
virtual std::vector<Expr> Split(const Expr& loop,
const std::vector<Expr>& factors) = 0;
virtual std::vector<Expr> SamplePerfectTile(
utils::LinearRandomEngine::StateType* rand_seed,
const Expr& loop,
Expand Down