From f9f03897bde45f716d09abceb0058c59e4734e72 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 10 May 2021 23:32:39 +0800 Subject: [PATCH] `IsAffineBinding(...)` --- src/tir/schedule/analysis.h | 4 +++- src/tir/schedule/analysis/analysis.cc | 22 ++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 4de72c54ed..6ceac5f9ef 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace tvm { namespace tir { @@ -60,7 +61,8 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, /******** Binding ********/ // Todo -bool ValidateBlockBinding(const BlockRealize& realize, const Map& loop_var_ranges); +bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, + arith::Analyzer* analyzer); /******** Block-loop relation ********/ /*! diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 376ccb4c8e..6d1c46a068 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -110,21 +110,17 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, /******** Binding ********/ -bool ValidateBlockBinding(const BlockRealize& realize, const Map& loop_var_ranges) { +bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, + arith::Analyzer* analyzer) { if (loop_var_ranges.empty()) { return true; } - arith::Analyzer analyzer; - Array leaf_iters; - leaf_iters.reserve(static_cast(realize->block->iter_vars.size())); - for (const IterVar& iter_var : realize->block->iter_vars) { - leaf_iters.push_back(iter_var); - } Array results = arith::DetectIterMap( - /*leaf_iters=*/leaf_iters, - /*bindings=*/realize->iter_values, // Todo - /*root_iters=*/loop_var_ranges, - /*input_pred=*/realize->predicate, /*analyzer=*/&analyzer); + /*indices=*/realize->iter_values, + /*input_iters=*/loop_var_ranges, + /*predicate=*/realize->predicate, + /*require_bijective=*/false, + /*analyzer=*/analyzer); if (results.empty()) { return false; } @@ -400,7 +396,9 @@ void UpdateAffineFlag(ScheduleState self, const StmtSRef& block_sref) { } } ICHECK(self->block_info.count(block_sref)); - self->block_info[block_sref].affine_binding = ValidateBlockBinding(realize, loop_var_ranges); + arith::Analyzer analyzer; + self->block_info[block_sref].affine_binding = + IsAffineBinding(realize, loop_var_ranges, &analyzer); } /******** Pattern Matcher ********/