Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] Analyzer CanonicalSimplifier #2891

Merged
merged 9 commits into from
Mar 31, 2019
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,47 @@ class RewriteSimplifier {
private:
friend class Analyzer;
friend class ConstraintContext;
friend class CanonicalSimplifier;
explicit RewriteSimplifier(Analyzer* parent);
~RewriteSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief Canonical-form based simplifier.
*/
class CanonicalSimplifier {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);

/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
bool override = false);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit CanonicalSimplifier(Analyzer* parent);
~CanonicalSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief A RAII constraint context.
*
Expand Down Expand Up @@ -277,6 +311,8 @@ class Analyzer {
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplfy */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer rewrite simplfy */
CanonicalSimplifier canonical_simplify;
/*! \brief constructor */
Analyzer();
/*!
Expand Down
39 changes: 39 additions & 0 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>
#include <memory>
#include <functional>
#include <utility>
#include "runtime/registry.h"

namespace tvm {
Expand All @@ -32,6 +33,44 @@ using ::tvm::AttrVisitor;
using ContainerType = NodeName; \
}; \

/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
inline NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
} \
return static_cast<NodeName*>(node_.get()); \
} \
using ContainerType = NodeName; \
};


/*!
* \brief save the node as well as all the node it depends on as json.
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(self):
self._bind = _mod("bind")
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify")
self._enter_constraint_context = _mod("enter_constraint_context")

def const_int_bound(self, expr):
Expand Down Expand Up @@ -144,6 +145,21 @@ def rewrite_simplify(self, expr):
"""
return self._rewrite_simplify(expr)

def canonical_simplify(self, expr):
"""Simplify expression via canonicalization.

Parameters
----------
expr : tvm.Expr
The expression.

Returns
-------
result : Expr
The result.
"""
return self._canonical_simplify(expr)

def bind(self, var, expr):
"""Bind a variable to the expression.

Expand Down
4 changes: 4 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]);
});
} else if (name == "canonical_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
Expand Down
16 changes: 12 additions & 4 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,21 @@ namespace arith {
Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this),
rewrite_simplify(this) {
rewrite_simplify(this),
canonical_simplify(this) {
}

void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Var var(v.node_);
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
this->rewrite_simplify.Update(var, this->rewrite_simplify(expr));

Expr new_expr = expr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why do you copy expr to new_expr here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make new expr mutable

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But none of the subsequent lines changes it, or am I wrong?.

new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);

this->const_int_bound.Update(var, this->const_int_bound(new_expr));
this->modular_set.Update(var, this->modular_set(new_expr));
this->rewrite_simplify.Update(var, new_expr);
this->canonical_simplify.Update(var, new_expr);
}

void Analyzer::Bind(const VarExpr& v, const Range& range) {
Expand Down Expand Up @@ -47,5 +54,6 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (bd->min_value >= lower_bound) return true;
return false;
}

} // namespace arith
} // namespace tvm
Loading