Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Attribute system #51636

Merged
Merged
20 changes: 20 additions & 0 deletions paddle/ir/attribute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/attribute.h"
#include "paddle/ir/dialect.h"

namespace ir {
IrContext *Attribute::ir_context() const { return dialect().ir_context(); }
} // namespace ir
92 changes: 92 additions & 0 deletions paddle/ir/attribute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/ir/attribute_base.h"
#include "paddle/ir/cast_utils.h"

namespace ir {
///
/// \brief Unified interface of the Attribute class. Derivation of all Attribute
/// classes only derives interfaces, not members.
///
class Attribute {
public:
using Storage = AttributeStorage;

constexpr Attribute() = default;

Attribute(const Storage *storage) // NOLINT
: storage_(storage) {}

Attribute(const Attribute &other) = default;

Attribute &operator=(const Attribute &other) = default;

bool operator==(Attribute other) const { return storage_ == other.storage_; }

bool operator!=(Attribute other) const { return storage_ != other.storage_; }

explicit operator bool() const { return storage_; }

bool operator!() const { return storage_ == nullptr; }

///
/// \brief Some Attribute attribute acquisition interfaces.
///
TypeId type_id() { return storage_->abstract_attribute().type_id(); }

const AbstractAttribute &abstract_attribute() {
return storage_->abstract_attribute();
}

const Storage *storage() const { return storage_; }

const Dialect &dialect() const {
return storage_->abstract_attribute().dialect();
}

IrContext *ir_context() const;

///
/// \brief Methods for type judgment and cast.
///
static bool classof(Attribute) { return true; }

template <typename T>
bool isa() const {
return ir::isa<T>(*this);
}

template <typename U>
U dyn_cast() const {
return ir::dyn_cast<U>(*this);
}

friend struct std::hash<Attribute>;

protected:
const Storage *storage_{nullptr};
};
} // namespace ir

namespace std {
template <>
struct hash<ir::Attribute> {
std::size_t operator()(const ir::Attribute &obj) const {
return std::hash<const ir::Attribute::Storage *>()(obj.storage_);
}
};
} // namespace std
282 changes: 282 additions & 0 deletions paddle/ir/attribute_base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/ir/ir_context.h"
#include "paddle/ir/storage_manager.h"
#include "paddle/ir/type_id.h"

namespace ir {
class Dialect;

///
/// \brief Abstract the properties and behaviors common to all Attribute classes
/// into an AbstractAttribute class.
///
class AbstractAttribute {
public:
///
/// \brief Construct an AbstractAttribute by TypeId directly.
///
/// \param type_id The id of the AbstractAttribute.
/// \param dialect The Dialect which the attribute registered to.
///
static AbstractAttribute get(TypeId type_id, const Dialect &dialect) {
return AbstractAttribute(type_id, dialect);
}

///
/// \brief Construct an AbstractAttribute by TypeId directly.
///
/// \param dialect The Dialect which the attribute registered to.
///
template <typename T>
static AbstractAttribute get(const Dialect &dialect) {
return AbstractAttribute(TypeId::get<T>(), dialect);
}

///
/// \brief Returns the type id of the AbstractAttribute.
///
/// \return The id of the AbstractAttribute.
///
TypeId type_id() const { return type_id_; }

///
/// \brief Get the dialect this attribute was registered to.
///
/// \return The dialect this attribute was registered to.
///
const Dialect &dialect() const { return dialect_; }

///
/// \brief Find the AbstractAttribute instance whose TypeId is type_id from
/// IrContext.
///
/// \param type_id The type id of the AbstractAttribute.
/// \param ctx The IrContext.
/// \return The AbstractAttribute instance whose TypeId is type_id.
///
static const AbstractAttribute &lookup(TypeId type_id, IrContext *ctx);

private:
///
/// \brief The constructor is set to private and provides the user with the
/// get method to obtain and manage the AbstractAttribute.
///
/// \param type_id The type id of the AbstractAttribute.
/// \param dialect The Dialect which the attribute registered to.
///
explicit AbstractAttribute(TypeId type_id, const Dialect &dialect)
: type_id_(type_id), dialect_(dialect) {}

TypeId type_id_;
const Dialect &dialect_;
};

struct AttributeManager;

///
/// \brief AttributeStorage is used to store all information of a Attribute. A
/// Attribute object contains a AttributeStorage. For non-parameter attribute,
/// the information includes: TypeId, so AttributeStorage only needs to include
/// AbstractAttribute; For parameteric attribute, in addition to
/// AbstractAttribute/TypeId, parameteric information needs to be included. So
/// that, non-parameteric attribute can be constructed by AttributeStorage
/// directly but parameteric attribute should be constructed by Derived
/// AttributeStorage.
///
class AttributeStorage : public StorageManager::StorageBase {
friend StorageManager;
friend AttributeManager;

public:
///
/// \brief Construct a AttributeStorage and initialize abstract_attribute.
///
/// \param abstract_attribute The abstract_attribute of this AttributeStorage.
///
explicit AttributeStorage(AbstractAttribute *abstract_attribute)
: abstract_attribute_(abstract_attribute) {}

AttributeStorage() {}

///
/// \brief Returns the AbstractAttribute of the AttributeStorage.
///
/// \return The AbstractAttribute of the AttributeStorage.
///
const AbstractAttribute &abstract_attribute() const {
return *abstract_attribute_;
}

private:
///
/// \brief Initialize AttributeStorage based on the AbstractAttribute*
/// provided by the user
///
/// \param abstract_attribute AbstractAttribute* provided by the user, the
/// construction method of AbstractAttribute refers to AbstractAttribute::get.
///
void initialize(const AbstractAttribute &abstract_attribute) {
abstract_attribute_ = const_cast<AbstractAttribute *>(&abstract_attribute);
}

AbstractAttribute *abstract_attribute_{nullptr}; // not owned
};

///
/// \brief AttributeManager is a utility class that provides interfaces for get
/// or unique Attribute instances in IrContext.
///
struct AttributeManager {
///
/// \brief Get a unique instance of Attribute T from IrContext. Note: For a
/// parameteric attribute, if not found in IrContext, it will try to create a
/// new instance and register it to IrContext; for a parameterless attribute,
/// only search.
///
/// \param ctx The IrContext instance.
/// \param args Parameters of the wrapped function.
/// \return The unique instance of Attribute T from IrContext.
///
template <typename T, typename... Args>
static T get(IrContext *ctx, Args &&...args) {
return get<T, Args...>(
ctx, ir::TypeId::get<T>(), std::forward<Args>(args)...);
}

///
/// \brief Get a unique instance of parametric Attribute T from IrContext. If
/// not found in IrContext, it will try to create a new instance and register
/// it to IrContext;
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the AbstractAttribute.
/// \param args Parameters of the wrapped function.
/// \return The unique instance of Attribute T from IrContext.
///
template <typename T, typename... Args>
static std::enable_if_t<
!std::is_same<typename T::Storage, AttributeStorage>::value,
T>
get(IrContext *ctx, TypeId type_id, Args &&...args) {
return ctx->attribute_storage_manager()
.GetParametricStorage<typename T::Storage>(
[&, type_id](AttributeStorage *storage) {
storage->initialize(AbstractAttribute::lookup(type_id, ctx));
},
type_id,
std::forward<Args>(args)...);
}

///
/// \brief Get a unique instance of parameterless Attribute T from IrContext.
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the AbstractAttribute.
/// \return The unique instance of Attribute T from IrContext.
///
template <typename T>
static std::
enable_if_t<std::is_same<typename T::Storage, AttributeStorage>::value, T>
get(IrContext *ctx, TypeId type_id) {
return ctx->attribute_storage_manager()
.GetParameterlessStorage<typename T::Storage>(type_id);
}

///
/// \brief Register a unique instance of Attribute T to IrContext.
///
/// \param ctx The IrContext instance.
///
template <typename T>
static void RegisterAttribute(IrContext *ctx) {
RegisterAttribute<T>(ctx, ir::TypeId::get<T>());
}

///
/// \brief Register a unique parametric Attribute T to IrContext.
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the Attribute T.
///
template <typename T>
static std::enable_if_t<
!std::is_same<typename T::Storage, AttributeStorage>::value>
RegisterAttribute(IrContext *ctx, TypeId type_id) {
ctx->attribute_storage_manager()
.RegisterParametricStorage<typename T::Storage>(type_id);
}

///
/// \brief Register a unique parameterless Attribute T to IrContext.
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the Attribute T.
///
template <typename T>
static std::enable_if_t<
std::is_same<typename T::Storage, AttributeStorage>::value>
RegisterAttribute(IrContext *ctx, TypeId type_id) {
ctx->attribute_storage_manager()
.RegisterParameterlessStorage<AttributeStorage>(
type_id, [&ctx, type_id](AttributeStorage *storage) {
storage->initialize(AbstractAttribute::lookup(type_id, ctx));
});
}
};

///
/// \brief Add some necessary functions to the custom Attribute class.
///
#define DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(concrete_attribute, storage_type) \
using Storage = storage_type; \
\
const Storage *storage() const { \
return static_cast<const Storage *>(this->storage_); \
} \
\
static ir::TypeId type_id() { \
return ir::TypeId::get<concrete_attribute>(); \
} \
\
template <typename T> \
static bool classof(T val) { \
return val.type_id() == type_id(); \
} \
\
template <typename... Args> \
static concrete_attribute get(ir::IrContext *ctx, Args... args) { \
return ir::AttributeManager::template get<concrete_attribute>(ctx, \
args...); \
}

///
/// \brief This macro definition is used to register custom Attribute class.
///
#define REGISTER_ATTRIBUTE_2_IRCONTEXT(concrete_attribute, dialect) \
ir::AbstractAttribute *abstract_attribute_##concrete_attribute = \
new ir::AbstractAttribute(std::move( \
ir::AbstractAttribute::get<concrete_attribute>(*dialect))); \
\
dialect->ir_context()->RegisterAbstractAttribute( \
ir::TypeId::get<concrete_attribute>(), \
abstract_attribute_##concrete_attribute); \
\
ir::AttributeManager::RegisterAttribute<concrete_attribute>( \
dialect->ir_context());

} // namespace ir
Loading