Skip to content

Commit

Permalink
merge DetectIterMap and DetectIterMapPadded
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif committed May 23, 2022
1 parent 5a2d333 commit 700b702
Show file tree
Hide file tree
Showing 14 changed files with 419 additions and 299 deletions.
103 changes: 45 additions & 58 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,53 +259,29 @@ class IterSumExpr : public IterMapExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode);
};

/*! \brief Mapping level for iterators. */
enum IterMapLevel {
// Require the mapping to be bijective.
Bijective = 0,
// Require the mapping to be subjective.
Surjective = 1,
// Require the mapping to be injective.
Injective = 2
};

/*!
* \brief Detect if indices can be written as
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
*
* Here y = some-quasi-affine-iter-map(input_iters)
* and c are symbolic constants.
*
* We also requires that y_i and y_j to be independent for i != j.
*
* For returned value rv, the following is always true:
* - rv[i]->args.size() <=1: only one iterator per element.
*
* \param indices The indices to detect pattern for.
* \param input_iters Map from variable to iterator's range.
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
* \param simplify_trivial_iterators If true, iterators with extent of
* 1 will be replaced with a constant value.
*
* \return The detected pattern if a match exists,
* otherwise return an empty array.
* \brief Result of DetectIterMap.
*/
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
class IterMapResultNode : public Object {
public:
// The detected pattern if a match exists.
Array<IterSumExpr> indices;

/*! \brief A utility struct for return values from DetectPaddedIterMap
*/
struct PaddedIterMapResult {
// Any errors that occurred while converting the input indices. If
// the array is empty, the conversion was successful.
Array<String> errors;

// The detected pattern if a match exists.
Array<IterSumExpr> indices;

/* \brief Boolean expression indicating if padding was required
*
* `requires_padding` evaluates to true if the returned indices
* contain padding relative to the provided expressions, and false
* otherwise. If `input_iters` contains a variable extent, this
* expression may be in terms of those variables.
*/
PrimExpr requires_padding;

/* \brief Boolean expression indicating if a specific value w
/*! \brief Boolean expression indicating if a specific value w
*
* `padding_predicate` evaluates to true for a set of indices that
* are outside the bounds of the provided index iterators, but
Expand All @@ -314,43 +290,54 @@ struct PaddedIterMapResult {
* `input_iters`.
*/
PrimExpr padding_predicate;

// overrides
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("errors", &errors);
v->Visit("indices", &indices);
v->Visit("padding_predicate", &padding_predicate);
}

static constexpr const char* _type_key = "arith.IterMapResult";
TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object);
};

/*!
* \brief Managed reference to IterMapResultNode.
* \sa IterMapResultNode
*/
class IterMapResult : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(IterMapResult, ObjectRef, IterMapResultNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMapResultNode);
};

/*!
* \brief Detect if indices can be written as
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
*
* Here y = some-quasi-affine-iter-map(input_iters) and c are
* symbolic constants. The y_i iterators may be padded to fit this
* representation.
* Here y = some-quasi-affine-iter-map(input_iters)
* and c are symbolic constants.
*
* We also requires that y_i and y_j to be independent for i != j.
*
* For returned value rv, the following is always true:
* - rv.indices[i]->args.size() <=1: only one iterator per element.
* - rv[i]->args.size() <=1: only one iterator per element.
*
* \param indices The indices to detect pattern for.
*
* \param input_iters Map from variable to iterator's range.
*
* \param predicate The predicate constraints on the input iterators
*
* \param require_bijective A boolean flag that indicates whether the
* mapping should be bijective. If true, no padding may be
* introduced.
*
* \param check_level The iter mapping check level.
* \param analyzer Analyzer used to get context information.
*
* \param simplify_trivial_iterators If true, iterators with extent of
* 1 will be replaced with a constant value.
*
* \return An instance of PaddedIterMapResult.
* \return The detected iteration result.
* The return object's .indices is empty on failure.
*/
PaddedIterMapResult DetectPaddedIterMap(const Array<PrimExpr>& indices,
const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer,
bool simplify_trivial_iterators = true);
IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, IterMapLevel check_level,
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);

/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,14 @@ def detect_iter_map(
Returns
-------
results : List[IterSumExpr]
results : IterMapResult
The iter map matching result.
Empty array if no match can be found.
The result's .indices is empty array if no match can be found.
"""
return _ffi_api.DetectIterMap(
indices, input_iters, predicate, require_bijective, simplify_trivial_iterators
)
).indices


def normalize_iter_map_to_expr(expr):
Expand Down
5 changes: 3 additions & 2 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -867,9 +867,10 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
for (const Range& range : region) {
affine_indices.push_back(range->min);
}
iter_sum_exprs = DetectIterMap(
auto res = DetectIterMap(
/*indices=*/affine_indices, /*input_iters=*/var_dom,
/*predicate=*/predicate, /*require_bijective=*/false, analyzer);
/*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer);
iter_sum_exprs = res->indices;
}
if (iter_sum_exprs.empty()) {
return NullOpt;
Expand Down
Loading

0 comments on commit 700b702

Please sign in to comment.