[IR] Support op attribute and refactor for new op definition #54068

Merged
Changes from 36 commits
Commits (41)
2b5cc89
add vector type support for program translator
kangguangli May 23, 2023
489b2cc
polish
kangguangli May 23, 2023
324d897
support basic attribute type
kangguangli May 23, 2023
f388e79
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into add_vec…
kangguangli May 23, 2023
f9e17d6
resolve conflicts
kangguangli May 23, 2023
9cd463f
add verify for combine/slice and unittests
kangguangli May 24, 2023
6d17079
polish
kangguangli May 24, 2023
b654cc2
Merge branch 'add_vector_type_support_for_program_translator' into su…
kangguangli May 24, 2023
b45df9c
support more type in attribute translator
kangguangli May 24, 2023
5940ac4
modify by reviews
kangguangli May 24, 2023
b3b1856
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support…
kangguangli May 24, 2023
b328cbb
fix merge mistakes
kangguangli May 24, 2023
640c27f
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support…
kangguangli May 25, 2023
e35a288
refine code
zhangbo9674 May 26, 2023
1ba19dc
refine code
zhangbo9674 May 26, 2023
6c2ac5d
add interface
zhangbo9674 May 26, 2023
8473bce
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support…
kangguangli May 29, 2023
dd2060b
Merge pull/54130/head
kangguangli May 29, 2023
6967eba
fix: op name normalization
kangguangli May 29, 2023
df0b5cd
fix typo
kangguangli May 29, 2023
f1c02dd
refactor input translator
kangguangli May 30, 2023
a053a11
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support…
kangguangli May 30, 2023
0854418
fix merge conflicts
kangguangli May 30, 2023
38f7453
fix op normalizer bug
kangguangli May 30, 2023
e8c6234
refactor attribute translator
kangguangli May 30, 2023
261fe97
fix bug
kangguangli May 30, 2023
6d57a5f
refactor output translator
kangguangli May 31, 2023
0717f11
fix typo
kangguangli May 31, 2023
c9d7c63
fix
kangguangli May 31, 2023
edfad35
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support…
kangguangli Jun 1, 2023
2fb54d8
fix approval error
kangguangli Jun 1, 2023
e945a41
fix coverage
kangguangli Jun 1, 2023
c9358f4
fix op_compat parser
kangguangli Jun 1, 2023
30540e7
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support…
kangguangli Jun 2, 2023
b785d35
fix merge conflicts
kangguangli Jun 2, 2023
a0f21e6
fix merge conflicts
kangguangli Jun 2, 2023
c2ede6f
fix merge conflicts
kangguangli Jun 2, 2023
2605b6e
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support…
kangguangli Jun 2, 2023
7f97345
fix merge conflicts
kangguangli Jun 2, 2023
5f81713
fix merge conflicts
kangguangli Jun 2, 2023
31bdbf5
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support…
kangguangli Jun 2, 2023
2 changes: 1 addition & 1 deletion paddle/fluid/dialect/op_gen.py
@@ -1186,7 +1186,7 @@ def OpGenerator(
)

# generate op verify function
if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
if "Grad" in op_class_name:
op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
)
4 changes: 3 additions & 1 deletion paddle/fluid/dialect/pd_attribute.cc
@@ -18,7 +18,9 @@ namespace paddle {
namespace dialect {
phi::IntArray IntArrayAttribute::data() const { return storage()->GetAsKey(); }

-phi::Scalar ScalarAttribute::data() const { return storage()->GetAsKey(); }
+paddle::experimental::Scalar ScalarAttribute::data() const {
Contributor commented:

Is there a particular reason for changing this from phi:: to paddle::experimental::?

Contributor (Author) replied:

They are two different types: the Attribute type in OpDesc is paddle::experimental::Scalar, so the change was made here. It will be re-adapted later according to the newly discussed ScalarAttribute design.

+  return storage()->GetAsKey();
+}

phi::DataType DataTypeAttribute::data() const { return storage()->GetAsKey(); }

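Editor's note on the thread above: the choice matters because framework::Attribute is a variant, and paddle::visit selects the visitor overload whose parameter type exactly matches the stored alternative. The self-contained C++ sketch below (standard std::variant with placeholder scalar types, not the actual Paddle definitions) illustrates that dispatch behaviour.

#include <iostream>
#include <variant>

// Placeholder types standing in for phi::Scalar and
// paddle::experimental::Scalar: they look alike but are distinct types,
// and only one of them is an alternative of the attribute variant.
struct PhiScalar { int v; };
struct ExperimentalScalar { int v; };

// Stand-in for framework::Attribute, whose scalar alternative is the
// "experimental" type, mirroring the reply above.
using Attribute = std::variant<int, ExperimentalScalar>;

struct Visitor {
  void operator()(int i) const { std::cout << "int: " << i << "\n"; }
  // Selected because the parameter matches the stored alternative exactly;
  // an overload taking PhiScalar would never be chosen for this variant.
  void operator()(const ExperimentalScalar& s) const {
    std::cout << "scalar: " << s.v << "\n";
  }
};

int main() {
  Attribute attr = ExperimentalScalar{42};
  std::visit(Visitor{}, attr);  // prints "scalar: 42"
  return 0;
}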
2 changes: 1 addition & 1 deletion paddle/fluid/dialect/pd_attribute.h
@@ -43,7 +43,7 @@ class ScalarAttribute : public ir::Attribute {
return storage() < right.storage();
}

-  phi::Scalar data() const;
+  paddle::experimental::Scalar data() const;
};

class DataTypeAttribute : public ir::Attribute {
4 changes: 2 additions & 2 deletions paddle/fluid/dialect/pd_attribute_storage.h
@@ -55,7 +55,7 @@ struct IntArrayAttributeStorage : public ir::AttributeStorage {
};

struct ScalarAttributeStorage : public ir::AttributeStorage {
-  using ParamKey = phi::Scalar;
+  using ParamKey = paddle::experimental::Scalar;

explicit ScalarAttributeStorage(const ParamKey &key) { data_ = key; }

@@ -73,7 +73,7 @@ struct ScalarAttributeStorage : public ir::AttributeStorage {
ParamKey GetAsKey() const { return ParamKey(data_); }

private:
-  phi::Scalar data_;
+  paddle::experimental::Scalar data_;
};

struct DataTypeAttributeStorage : public ir::AttributeStorage {
5 changes: 4 additions & 1 deletion paddle/fluid/dialect/pd_interface.h
@@ -36,7 +36,10 @@ class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> {
struct Model : public Concept {
static OpInfoTuple GetOpInfo() { return ConcreteOp::GetOpInfo(); }

-  Model() : Concept(GetOpInfo) {}
+  Model() : Concept(GetOpInfo) {
+    static_assert(sizeof(Model) == sizeof(Concept),
Contributor commented:

This static_assert(sizeof(Model) == sizeof(Concept)) is no longer needed and can be removed in the next PR.

"sizeof(Model) != sizeof(Concept)");
}
};

GetOpInfoInterface(ir::Operation *op, Concept *impl)
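Editor's note on the comment above: in this concept/model pattern, Model contributes only a static member function and no data members, so sizeof(Model) equals sizeof(Concept) by construction, which is why the assertion carries no information. A minimal, self-contained sketch of the pattern (placeholder names, not the actual Paddle classes):

#include <cassert>

// Base "concept" stores a plain function pointer that the model sets up.
struct Concept {
  explicit Concept(int (*get_info)()) : get_info_(get_info) {}
  int (*get_info_)();
};

// "Model" adds only a static function, no data members, so its size
// equals the base's size on common ABIs.
template <typename ConcreteOp>
struct Model : Concept {
  static int GetOpInfo() { return ConcreteOp::kInfo; }
  Model() : Concept(GetOpInfo) {
    static_assert(sizeof(Model) == sizeof(Concept),
                  "Model must not add state");
  }
};

struct DummyOp {
  static constexpr int kInfo = 7;
};

int main() {
  Model<DummyOp> m;
  assert(m.get_info_() == 7);  // dispatches to DummyOp::GetOpInfo
  return 0;
}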
14 changes: 4 additions & 10 deletions paddle/fluid/dialect/pd_op.yaml
@@ -1,12 +1,7 @@
- name: feed
inputs:
- typename: Tensor[]
name: x
optional: false
no_need_buffer: false
data_transform: {}
inputs: []
attrs:
- {typename: int, name: col}
- {typename: str, name: name}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null
@@ -21,9 +16,8 @@
no_need_buffer: false
data_transform: {}
attrs:
- {typename: int, name: col}
outputs:
- {typename: 'Tensor[]', name: out, optional: false, intermediate: false}
- {typename: str, name: name}
outputs: []
no_need_buffer: null
data_transform: null
inplace: null
231 changes: 231 additions & 0 deletions paddle/fluid/translator/attribute_translator.cc
@@ -0,0 +1,231 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/translator/attribute_translator.h"

#include <string>
#include <vector>

#include "paddle/fluid/dialect/pd_attribute.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/variant.h"

namespace paddle {
namespace translator {

class AttributeVisitor {
public:
ir::IrContext* ctx;
AttributeVisitor() { ctx = ir::IrContext::Instance(); }
~AttributeVisitor() {}

public:
virtual ir::Attribute operator()(int i) {
VLOG(10) << "translating int";
return ir::Int32_tAttribute::get(ctx, i);
}

virtual ir::Attribute operator()(float f) {
VLOG(10) << "translating float";
return ir::FloatAttribute::get(ctx, f);
}

virtual ir::Attribute operator()(bool b) {
VLOG(10) << "translating bool";
return ir::BoolAttribute::get(ctx, b);
}

virtual ir::Attribute operator()(double d) {
VLOG(10) << "translating double";
return ir::DoubleAttribute::get(ctx, d);
}

virtual ir::Attribute operator()(std::string str) {
VLOG(10) << "translating string";
return ir::StrAttribute::get(ctx, str);
}

virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) {
VLOG(10) << "translating scalar";
return paddle::dialect::ScalarAttribute::get(ctx, scalar);
}

virtual ir::Attribute operator()(const std::vector<std::string>& strs) {
VLOG(10) << "translating vector<string>";
std::vector<ir::Attribute> attrs;
attrs.reserve(strs.size());
for (const auto& v : strs) {
attrs.push_back(ir::StrAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}

virtual ir::Attribute operator()(const std::vector<float>& fs) {
VLOG(10) << "translating vector<float>";
std::vector<ir::Attribute> attrs;
attrs.reserve(fs.size());
for (const auto& v : fs) {
attrs.push_back(ir::FloatAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}

virtual ir::Attribute operator()(const std::vector<int>& is) {
VLOG(10) << "translating vector<int>";
std::vector<ir::Attribute> attrs;
attrs.reserve(is.size());
for (const auto& v : is) {
attrs.push_back(ir::Int32_tAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}

virtual ir::Attribute operator()(const std::vector<bool>& bs) {
VLOG(10) << "translating vector<bool>";
std::vector<ir::Attribute> attrs;
attrs.reserve(bs.size());
for (const auto& v : bs) {
attrs.push_back(ir::BoolAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}

virtual ir::Attribute operator()(const std::vector<int64_t>& i64s) {
VLOG(10) << "translating vector<int64>";
std::vector<ir::Attribute> attrs;
attrs.reserve(i64s.size());
for (const auto& v : i64s) {
attrs.push_back(ir::Int64_tAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}

virtual ir::Attribute operator()(const std::vector<double>& ds) {
VLOG(10) << "translating vector<double>";
std::vector<ir::Attribute> attrs;
attrs.reserve(ds.size());
for (const auto& v : ds) {
attrs.push_back(ir::DoubleAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}

virtual ir::Attribute operator()(
const std::vector<paddle::experimental::Scalar>& ss) {
VLOG(10) << "translating vector<scalar>";
std::vector<ir::Attribute> attrs;
attrs.reserve(ss.size());
for (const auto& v : ss) {
attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}

virtual ir::Attribute operator()(const paddle::blank& blank) {
VLOG(10) << "translating paddle::blank";
return ir::Attribute(nullptr);
}

template <typename T>
ir::Attribute operator()(T attr) {
VLOG(10) << "translating null type";
return ir::Attribute(nullptr);
}
};

class IntArrayAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(const std::vector<int>& is) override {
VLOG(10) << "translating vector<int> to IntArray";
phi::IntArray data(is);
return paddle::dialect::IntArrayAttribute::get(ctx, data);
}

ir::Attribute operator()(const std::vector<int64_t>& is) override {
VLOG(10) << "translating vector<int> to IntArray";
phi::IntArray data(is);
return paddle::dialect::IntArrayAttribute::get(ctx, data);
}
};

class ScalarAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(int i) override {
VLOG(10) << "translating int to Scalar";
phi::Scalar data(i);
return paddle::dialect::ScalarAttribute::get(ctx, data);
}

ir::Attribute operator()(float f) override {
VLOG(10) << "translating float to Scalar";
phi::Scalar data(f);
return paddle::dialect::ScalarAttribute::get(ctx, data);
}
};

class DataTypeAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(int i) override {
VLOG(10) << "translating int to DataType: " << i;
phi::DataType data = static_cast<phi::DataType>(i);
return paddle::dialect::DataTypeAttribute::get(ctx, data);
}
};

class PlaceAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;

ir::Attribute operator()(const paddle::blank& blank) override {
VLOG(10) << "translating paddle::blank";
phi::Place data(phi::AllocationType::CPU);
return paddle::dialect::PlaceAttribute::get(ctx, data);
}
};

AttributeTranslator::AttributeTranslator() {
general_visitor = new AttributeVisitor();
special_visitors["paddle::dialect::IntArrayAttribute"] =
new IntArrayAttributeVisitor();
special_visitors["paddle::dialect::ScalarAttribute"] =
new ScalarAttributeVisitor();
special_visitors["paddle::dialect::DataTypeAttribute"] =
new DataTypeAttributeVisitor();
special_visitors["paddle::dialect::PlaceAttribute"] =
new PlaceAttributeVisitor();
}

ir::Attribute AttributeTranslator::operator()(
const framework::Attribute& attr) {
return paddle::visit(*general_visitor, attr);
}

ir::Attribute AttributeTranslator::operator()(
const std::string& target_type, const framework::Attribute& attr) {
if (special_visitors.find(target_type) == special_visitors.end()) {
VLOG(10) << "[" << target_type << "] not found";
return paddle::visit(*general_visitor, attr);
}
return paddle::visit(*(special_visitors.at(target_type)), attr);
}

} // namespace translator
} // namespace paddle
54 changes: 54 additions & 0 deletions paddle/fluid/translator/attribute_translator.h
@@ -0,0 +1,54 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/ir_context.h"

#pragma once

namespace paddle {
namespace translator {

class AttributeVisitor;

class AttributeTranslator {
private:
AttributeTranslator();
AttributeVisitor* general_visitor;
std::unordered_map<std::string, AttributeVisitor*> special_visitors;

public:
AttributeTranslator(const AttributeTranslator&) = delete;
AttributeTranslator& operator=(const AttributeTranslator&) = delete;
AttributeTranslator(AttributeTranslator&&) = delete;
AttributeTranslator& operator=(AttributeTranslator&&) = delete;

static auto& instance() {
static AttributeTranslator attribute_translator;
return attribute_translator;
}

ir::Attribute operator()(const framework::Attribute& attr);
ir::Attribute operator()(const std::string& target_type,
const framework::Attribute& attr);
};

} // namespace translator
} // namespace paddle
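Editor's note: a brief usage sketch of the new translator. The call site below is hypothetical (not part of this PR); legacy_attr stands for a framework::Attribute taken from an OpDesc, and the target-type string must be one of the keys registered in AttributeTranslator's constructor.

// Hypothetical call site (illustrative only).
#include "paddle/fluid/translator/attribute_translator.h"

void TranslateOneAttribute(const paddle::framework::Attribute& legacy_attr) {
  auto& translator = paddle::translator::AttributeTranslator::instance();

  // Generic path: dispatch purely on the variant's runtime alternative.
  ir::Attribute generic = translator(legacy_attr);

  // Targeted path: ask for a specific dialect attribute, e.g. pack a
  // vector<int>/vector<int64_t> into an IntArrayAttribute. Unknown target
  // names fall back to the generic visitor (see attribute_translator.cc).
  ir::Attribute as_int_array =
      translator("paddle::dialect::IntArrayAttribute", legacy_attr);

  (void)generic;
  (void)as_int_array;
}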