Skip to content

Commit

Permalink
Support match pvar with dtype constraint (apache#9016)
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif authored and ylc committed Sep 29, 2021
1 parent 408d92f commit a366b02
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
69 changes: 68 additions & 1 deletion src/arith/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,73 @@ class PVar : public Pattern<PVar<T>> {
mutable bool filled_{false};
};

/*!
* \brief Wrapper for pattern variable container with extra match logic.
*
* \tparam Derived the type of derived class.
* \tparam T the type of the hole.
*/
template <typename Derived, typename T>
class PVarWithCheck : public arith::Pattern<PVarWithCheck<Derived, T>> {
public:
// Store by reference in the expression.
using Nested = const PVarWithCheck<Derived, T>&;

void InitMatch_() const { pvar_.InitMatch_(); }

bool Match_(const T& value) const {
if (!static_cast<const Derived*>(this)->Match_(value)) return false;
return pvar_.Match_(value);
}

template <typename NodeRefType,
typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type>
bool Match_(const NodeRefType& value) const {
if (const auto* ptr = value.template as<typename T::ContainerType>()) {
return Match_(GetRef<T>(ptr));
} else {
return false;
}
}

T Eval() const { return pvar_.Eval(); }

protected:
arith::PVar<T> pvar_;
};

/*!
* \brief Pattern variable container with expr type check.
*
* \tparam T the type of the hole.
* \tparam DType the Pattern type of dtype.
*/
template <typename T, typename DType,
typename = std::enable_if<std::is_base_of<T, PrimExpr>::value>>
class PVarWithDataType : public PVarWithCheck<PVarWithDataType<T, DType>, T> {
public:
explicit PVarWithDataType(const DType& dtype) : dtype_(dtype) {}

bool Match_(const T& value) const { return dtype_.Match_(value->dtype); }

protected:
typename DType::Nested dtype_;
};

/*!
* \brief Pattern variable container for data type with lanes.
*/
class PVecDataType : public PVarWithCheck<PVecDataType, DataType> {
public:
/*! \brief construct vector dtype placeholder with element type check */
explicit PVecDataType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {}

bool Match_(const DataType& dtype) const { return dtype.code() == elem_dtype_.code(); }

protected:
DataType elem_dtype_;
};

/*!
* \brief Constant Pattern variable container.
*
Expand Down Expand Up @@ -467,7 +534,7 @@ class PCastExpr : public Pattern<PCastExpr<DType, TA>> {
/*!
* \brief Construct a cast pattern.
*
* \param dtype The target data type, can be PVar<Type> or PConst<Type>.
* \param dtype The target data type, can be PVar<DataType> or PConst<DataType>.
* \param value The input type.
*
* \return The result pattern.
Expand Down
22 changes: 22 additions & 0 deletions tests/cpp/pattern_match_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,25 @@ TEST(Pattern, IntImm) {
// cannot match tx + 1 to v
ICHECK(!(v * c).Match((tx + 1) * 3));
}

TEST(Pattern, MatchWithType) {
using namespace tvm;
// match expr with specified dtype
arith::PVarWithDataType<PrimExpr, arith::PConst<DataType>> pat(DataType::Float(32));
tir::Var x("x", DataType::Float(32));
tir::Var y("y", DataType::Float(32));
tir::Var x_int("x", DataType::Int(32));
tir::Var y_int("y", DataType::Int(32));
ICHECK(pat.Match(x + y * 2.0f));
ICHECK(!pat.Match(x_int + y_int * 2));

// match vectorized expr with specified element dtype
arith::PVecDataType vec_ty(DataType::Float(32));
arith::PVarWithDataType<PrimExpr, arith::PVecDataType> vpat(vec_ty);
tir::Var vx = tir::Var("x", DataType::Float(32, 8));
tir::Var vy("y", DataType::Float(32, 8));
tir::Var vx_int("x", DataType::Int(32, 8));
tir::Var vy_int("y", DataType::Int(32, 8));
ICHECK(vpat.Match(vx + vy * tir::Broadcast(2.0f, 8)));
ICHECK(!vpat.Match(vx_int + vy_int * tir::Broadcast(2, 8)));
}

0 comments on commit a366b02

Please sign in to comment.