Skip to content

Commit

Permalink
[PIR] Unify dyn_cast interface with pir::cast (PaddlePaddle#57463)
Browse files Browse the repository at this point in the history
* PR comment

* unify dyn_cast_interface

* rm detail
  • Loading branch information
zhangbopd authored and Frida-a committed Oct 14, 2023
1 parent 9e3a56e commit ecd5426
Show file tree
Hide file tree
Showing 15 changed files with 124 additions and 134 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const phi::Place& AllocatedDenseTensorType::place() const {
return storage()->place_;
}

const pir::Type& AllocatedDenseTensorType::dtype() const {
pir::Type AllocatedDenseTensorType::dtype() const {
return storage()->dense_tensor_type_.dtype();
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/kernel/ir/kernel_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class AllocatedDenseTensorType

const phi::Place &place() const;

const pir::Type &dtype() const;
pir::Type dtype() const;

const phi::DDim &dims() const;

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class OpYamlInfoInterface : public pir::OpInterfaceBase<OpYamlInfoInterface> {
Model() : Concept(GetOpInfo) {}
};

/// Constructor
OpYamlInfoInterface(pir::Operation *op, Concept *impl)
: pir::OpInterfaceBase<OpYamlInfoInterface>(op), impl_(impl) {}

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/interface/vjp.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class VjpInterface : public pir::OpInterfaceBase<VjpInterface> {
Model() : Concept(Vjp) {}
};

/// Constructor
VjpInterface(pir::Operation* op, Concept* impl)
: pir::OpInterfaceBase<VjpInterface>(op), impl_(impl) {}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const {
if (auto tensor_type = type.dyn_cast<DenseTensorType>()) {
os << "tensor<";
for (auto d : phi::vectorize(tensor_type.dims())) {
pir::ShapedTypeInterface::isDynamic(d) ? os << "?" : os << d;
pir::ShapedTypeInterface::IsDynamic(d) ? os << "?" : os << d;
os << "x";
}
tensor_type.dtype().Print(os);
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/builtin_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
namespace pir {
std::vector<Type> VectorType::data() const { return storage()->GetAsKey(); }

const pir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }
pir::Type DenseTensorType::dtype() const { return storage()->dtype_; }

const DenseTensorTypeStorage::Dim& DenseTensorType::dims() const {
return storage()->dims_;
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/builtin_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class DenseTensorType : public Type::TypeBase<DenseTensorType,
public:
using Base::Base;

const Type &dtype() const;
Type dtype() const;

const DenseTensorTypeStorage::Dim &dims() const;

Expand Down
8 changes: 4 additions & 4 deletions paddle/pir/core/builtin_type_interfaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

namespace pir {

Type ShapedTypeInterface::getElementType() const {
return impl_->get_element_type_(*this);
Type ShapedTypeInterface::GetElementType() const {
return impl_->get_element_type(*this);
}

phi::DDim ShapedTypeInterface::getShape() const {
return impl_->get_shape_(*this);
phi::DDim ShapedTypeInterface::GetShape() const {
return impl_->get_shape(*this);
}

} // namespace pir
Expand Down
100 changes: 38 additions & 62 deletions paddle/pir/core/builtin_type_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,16 @@

#pragma once

#include <algorithm>
#include <vector>

#include "paddle/phi/core/tensor_base.h"
#include "paddle/pir/core/cast_utils.h"
#include "paddle/pir/core/enforce.h"
#include "paddle/pir/core/type.h"

namespace pir {

namespace detail {

template <typename RangeT>
constexpr auto begin_impl(RangeT &&range)
-> decltype(std::begin(std::forward<RangeT>(range))) {
return std::begin(std::forward<RangeT>(range));
}

template <typename RangeT>
constexpr auto end_impl(RangeT &&range)
-> decltype(std::end(std::forward<RangeT>(range))) {
return std::end(std::forward<RangeT>(range));
}

template <typename RangeT>
constexpr auto adl_begin(RangeT &&range)
-> decltype(begin_impl(std::forward<RangeT>(range))) {
return begin_impl(std::forward<RangeT>(range));
}

template <typename RangeT>
constexpr auto adl_end(RangeT &&range)
-> decltype(end_impl(std::forward<RangeT>(range))) {
return end_impl(std::forward<RangeT>(range));
}

template <typename R, typename UnaryPredicate>
bool any_of(R &&Range, UnaryPredicate P) {
return std::any_of(adl_begin(Range), adl_end(Range), P);
}

template <typename R, typename UnaryPredicate>
auto count_if(R &&Range, UnaryPredicate P) {
return std::count_if(adl_begin(Range), adl_end(Range), P);
}

} // namespace detail

class ShapedTypeInterface : public TypeInterfaceBase<ShapedTypeInterface> {
public:
using DDim = phi::DDim;
Expand All @@ -68,10 +32,10 @@ class ShapedTypeInterface : public TypeInterfaceBase<ShapedTypeInterface> {
/// Defined these methods with the interface.
explicit Concept(DataType (*get_element_type)(Type),
DDim (*get_shape)(Type))
: get_element_type_(get_element_type), get_shape_(get_shape) {}
: get_element_type(get_element_type), get_shape(get_shape) {}

DataType (*get_element_type_)(Type);
DDim (*get_shape_)(Type);
DataType (*get_element_type)(Type);
DDim (*get_shape)(Type);
};

template <class ConcreteType>
Expand All @@ -88,18 +52,27 @@ class ShapedTypeInterface : public TypeInterfaceBase<ShapedTypeInterface> {
};

/// Constructor
ShapedTypeInterface(std::nullptr_t) // NOLINT
: TypeInterfaceBase<ShapedTypeInterface>(Type()), impl_(nullptr) {}

explicit ShapedTypeInterface(Type type = Type())
: TypeInterfaceBase<ShapedTypeInterface>(type),
impl_(type
? type.abstract_type().GetInterfaceImpl<ShapedTypeInterface>()
: nullptr) {}

ShapedTypeInterface(Type type, Concept *impl)
: TypeInterfaceBase<ShapedTypeInterface>(type), impl_(impl) {}

///
/// \brief Get the element type.
///
DataType getElementType() const;
DataType GetElementType() const;

///
/// \brief Get the shape of this type.
///
DDim getShape() const;
DDim GetShape() const;

///
/// \brief kDynamic
Expand All @@ -109,62 +82,65 @@ class ShapedTypeInterface : public TypeInterfaceBase<ShapedTypeInterface> {
///
/// \brief Check whether this type is ranked, currently return true.
///
bool hasRank() const { return true; }
bool HasRank() const { return true; }

///
/// If this is a ranked type, return the rank. Otherwise, abort.
///
int64_t getRank() const {
IR_ENFORCE((*this).hasRank(), "Cannot query rank of unranked shaped type.");
return (*this).getShape().size();
int64_t GetRank() const {
IR_ENFORCE((*this).HasRank(), "Cannot query rank of unranked shaped type.");
return (*this).GetShape().size();
}

///
/// \brief Check whether the given dimension size is a dynamic dimension.
///
static constexpr bool isDynamic(int64_t dValue) { return dValue == kDynamic; }
static constexpr bool IsDynamic(int64_t dValue) { return dValue == kDynamic; }

///
/// \brief Check whether the given shape has any size indicating a dynamic
/// dimension.
///
static bool isDynamicShape(DDim dSizes) {
return detail::any_of(vectorize(dSizes),
[](int64_t dSize) { return isDynamic(dSize); });
static bool IsDynamicShape(DDim sizes) {
auto size_vec = vectorize(sizes);
return std::any_of(size_vec.begin(), size_vec.end(), [](int64_t size_vec) {
return IsDynamic(size_vec);
});
}

///
/// \brief Check whether shape has any size indicating a dynamic dimension.
///
bool hasStaticShape() const {
return (*this).hasRank() && !isDynamicShape((*this).getShape());
bool HasStaticShape() const {
return (*this).HasRank() && !IsDynamicShape((*this).GetShape());
}

///
/// \brief Check whether the given dimension has a dynamic size.Aborts for
/// unranked types.
///
bool isDynamicDim(unsigned idx) const {
IR_ENFORCE(idx < getRank(), "Invalid index for shaped type.");
return ShapedTypeInterface::isDynamic((*this).getShape()[idx]);
bool IsDynamicDim(unsigned idx) const {
IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type.");
return ShapedTypeInterface::IsDynamic((*this).GetShape()[idx]);
}

///
/// \brief Get the number of dimensions with dynamic size for a ranked type.
/// Aborts for unranked types.
///
int64_t getNumDynamicDims() const {
return detail::count_if(vectorize((*this).getShape()),
ShapedTypeInterface::isDynamic);
int64_t GetNumDynamicDims() const {
auto shape_vec = vectorize((*this).GetShape());
return std::count_if(
shape_vec.begin(), shape_vec.end(), ShapedTypeInterface::IsDynamic);
}

///
/// \brief Get the size of the specified dimension for a ranked type. Aborts
/// for unranked types.
///
int64_t getDimSize(unsigned idx) const {
IR_ENFORCE(idx < getRank(), "Invalid index for shaped type.");
return (*this).getShape()[idx];
int64_t GetDimSize(unsigned idx) const {
IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type.");
return (*this).GetShape()[idx];
}

private:
Expand Down
11 changes: 10 additions & 1 deletion paddle/pir/core/op_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,18 @@ class OpInterfaceBase : public OpBase {
public:
explicit OpInterfaceBase(Operation *op) : OpBase(op) {}

// Accessor for the ID of this interface.
///
/// \brief Accessor for the ID of this interface.
///
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }

///
/// \brief Checking if the given object defines the concrete interface.
///
static bool classof(Operation *op) {
return op->HasInterface<ConcreteInterface>();
}

static ConcreteInterface dyn_cast(Operation *op) {
if (op && op->HasInterface<ConcreteInterface>()) {
return ConcreteInterface(
Expand Down
14 changes: 7 additions & 7 deletions paddle/pir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,22 +172,22 @@ class IR_API alignas(8) Operation final {
detail::OpOperandImpl *op_operand_impl(uint32_t index);
const detail::OpOperandImpl *op_operand_impl(uint32_t index) const;

template <typename T, typename Enabler = void>
template <typename To, typename Enabler = void>
struct CastUtil {
static T call(Operation *op) {
throw("Can't dyn_cast to T, T should be a Op or Trait or Interface");
static To call(Operation *op) {
throw("Can't dyn_cast to To, To should be a Op or Trait or Interface");
}
};

// Allow access to 'SetParent'.
friend class Block;
void SetParent(Block *parent, const Block::Iterator &position);

template <typename T>
template <typename To>
struct CastUtil<
T,
typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> {
static T call(Operation *op) { return T::dyn_cast(op); }
To,
typename std::enable_if<std::is_base_of<OpBase, To>::value>::type> {
static To call(Operation *op) { return To::dyn_cast(op); }
};

AttributeMap attributes_;
Expand Down
7 changes: 7 additions & 0 deletions paddle/pir/core/storage_manager_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ class StorageHelperBase : public BaseT {
using InterfaceList =
typename Filter<TypeInterfaceBase, std::tuple<TraitOrInterface...>>::Type;

static ConcreteT dyn_cast_impl(BaseT type) {
if (type && type.abstract_type().type_id() == TypeId::get<ConcreteT>()) {
return ConcreteT(type.storage());
}
return ConcreteT(nullptr);
}

///
/// \brief Access to the storage instance.
///
Expand Down
Loading

0 comments on commit ecd5426

Please sign in to comment.