Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Add Three OPs with ReifyReturnTypeShapes #58368

Merged
merged 7 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/common/ddim.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ using common::vectorize;

namespace pir {
using DDim = common::DDim;
}
using LoD = std::vector<std::vector<size_t>>;
} // namespace pir

namespace std {
template <>
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/infer_type_op_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ namespace pir {

bool InferShapedTypeOpInterface::ReifyReturnTypeShapes(
Builder& builder,
std::vector<OpOperand> operands,
const std::vector<OpOperand>& operands,
std::vector<Value>& reified_return_shapes) {
return impl_->reify_return_type_shapes(
builder, operands, reified_return_shapes);
operation(), builder, operands, reified_return_shapes);
}
} // namespace pir

Expand Down
19 changes: 11 additions & 8 deletions paddle/pir/core/infer_type_op_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,26 @@ class InferShapedTypeOpInterface
/// Defined these methods with the interface.
struct Concept {
explicit Concept(bool (*reify_return_type_shapes)(
Builder& builder, // NOLINT
std::vector<OpOperand> operands, // NOLINT
Operation* op,
Builder& builder, // NOLINT
const std::vector<OpOperand>& operands,
std::vector<Value>& reified_return_shapes)) // NOLINT
: reify_return_type_shapes(reify_return_type_shapes) {}
bool (*reify_return_type_shapes)(
Operation* op,
Builder& builder,
std::vector<OpOperand> operands,
const std::vector<OpOperand>& operands,
std::vector<Value>& reified_return_shapes); // NOLINT
};

template <class ConcreteOp>
struct Model : public Concept {
static inline bool ReifyReturnTypeShapes(
Builder& builder, // NOLINT
std::vector<OpOperand> operands, // NOLINT
Operation* op,
Builder& builder, // NOLINT
const std::vector<OpOperand>& operands,
std::vector<Value>& reified_return_shapes) { // NOLINT
return ConcreteOp::ReifyReturnTypeShapes(
return op->dyn_cast<ConcreteOp>().ReifyReturnTypeShapes(
builder, operands, reified_return_shapes);
}

Expand All @@ -59,8 +62,8 @@ class InferShapedTypeOpInterface
: pir::OpInterfaceBase<InferShapedTypeOpInterface>(op), impl_(impl) {}

bool ReifyReturnTypeShapes(
Builder& builder, // NOLINT
std::vector<OpOperand> operands, // NOLINT
Builder& builder, // NOLINT
const std::vector<OpOperand>& operands,
std::vector<Value>& reified_return_shapes); // NOLINT

private:
Expand Down
6 changes: 5 additions & 1 deletion paddle/pir/dialect/shape/ir/shape_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/pir/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/dialect/shape/ir/shape_op.h"
#include "paddle/pir/dialect/shape/ir/shape_reify_infer_shape_op.h"

namespace pir::shape {
ShapeDialect::ShapeDialect(IrContext *context)
Expand All @@ -32,7 +33,10 @@ void ShapeDialect::initialize() {
FromElementsOp,
ExtractOp,
ConstantOp,
IndexCastOp>();
IndexCastOp,
AbsOp,
TransposeOp,
ConcatOp>();
}

void ShapeDialect::PrintOperation(Operation *op, IrPrinter &printer) const {
Expand Down
29 changes: 28 additions & 1 deletion paddle/pir/dialect/shape/ir/shape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include "paddle/pir/dialect/shape/ir/shape_op.h"
#include "paddle/common/enforce.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/builtin_type.h"
Expand Down Expand Up @@ -290,12 +289,37 @@ void ShapeOfOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Value input) {
argument.AddInput(input);

IrContext *ctx = IrContext::Instance();
Type dtype = IndexType::get(ctx);
int64_t input_rank = input.type()
.dyn_cast<DenseTensorType>()
.dyn_cast<ShapedTypeInterface>()
.GetRank();
pir::DDim dims = {input_rank};
pir::DataLayout data_layout = pir::DataLayout::NCHW;
pir::LoD lod = {{0, 1, 2}};
size_t offset = 0;

argument.output_types.emplace_back(
DenseTensorType::get(ctx, dtype, dims, data_layout, lod, offset));
}

// Builds a FromElementsOp: assembles a rank-1 index tensor whose entries are
// the given scalar `elements` (used to materialize a shape as a tensor value).
void FromElementsOp::Build(Builder &builder,  // NOLINT
                           OperationArgument &argument,  // NOLINT
                           const std::vector<Value> &elements) {
  argument.AddInputs(elements);

  IrContext *ctx = IrContext::Instance();
  // The result is a 1-D tensor of index type with one entry per element.
  Type dtype = IndexType::get(ctx);
  int64_t num_elements = elements.size();
  pir::DDim dims = {num_elements};
  pir::DataLayout data_layout = pir::DataLayout::NCHW;
  // NOTE(review): {{0, 1, 2}} looks like a placeholder LoD — the same
  // fabricated value appears in the sibling Build methods; confirm it is
  // never consumed downstream.
  pir::LoD lod = {{0, 1, 2}};
  size_t offset = 0;

  argument.output_types.emplace_back(
      DenseTensorType::get(ctx, dtype, dims, data_layout, lod, offset));
}

std::vector<Value> FromElementsOp::elements() {
Expand All @@ -312,6 +336,8 @@ void ExtractOp::Build(Builder &builder, // NOLINT
std::vector<Value> indices) {
argument.AddInput(tensor);
argument.AddInputs(indices);
auto type = tensor.type().dyn_cast<ShapedTypeInterface>().GetElementType();
argument.output_types.emplace_back(type);
}

std::vector<Value> ExtractOp::indices() {
Expand All @@ -334,6 +360,7 @@ void IndexCastOp::Build(Builder &builder, // NOLINT
Type out,
Value in) {
argument.AddInput(in);
argument.output_types.emplace_back(out);
}

} // namespace pir::shape
Expand Down
204 changes: 204 additions & 0 deletions paddle/pir/dialect/shape/ir/shape_reify_infer_shape_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/dialect/shape/ir/shape_reify_infer_shape_op.h"
#include "paddle/common/ddim.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/dialect/shape/ir/shape_op.h"

namespace pir::shape {

namespace {

// Reifies the shape of `operand` as a single ShapeOfOp result appended to
// `reified_return_shapes`. Returns false when the operand is not a shaped
// type (so no shape can be derived).
bool DeriveShapeFromOperand(Builder *builder,
                            Value operand,
                            std::vector<Value> *reified_return_shapes) {
  if (!operand.type().dyn_cast<ShapedTypeInterface>()) return false;
  Value shape = builder->Build<shape::ShapeOfOp>(operand).result(0);
  reified_return_shapes->clear();
  reified_return_shapes->push_back(shape);
  return true;
}

// Returns a scalar value of type `type` (which must be an integer or index
// type), inserting IndexCastOps as needed. A cast between two non-index
// integer types hops through the index type, because IndexCastOp only
// converts to or from index.
Value MaybeCastTo(Builder &builder, Value value, Type type) {  // NOLINT
  if (value.type() == type) return value;

  const bool neither_is_index = !type.IsIndex() && !value.type().IsIndex();
  if (neither_is_index) {
    Value as_index =
        builder.Build<shape::IndexCastOp>(builder.index_type(), value)
            .result(0);
    return builder.Build<shape::IndexCastOp>(type, as_index).result(0);
  }
  return builder.Build<shape::IndexCastOp>(type, value).result(0);
}
} // namespace

// Builds an AbsOp. Abs is element-wise, so the result keeps the operand's
// element type and dimensions.
void AbsOp::Build(Builder &builder, OperationArgument &argument, Value x) {
  argument.AddInput(x);

  IrContext *ctx = IrContext::Instance();
  Type element_type =
      x.type().dyn_cast<ShapedTypeInterface>().GetElementType();
  pir::DDim out_dims = x.type().dyn_cast<DenseTensorType>().dims();
  // NOTE(review): placeholder layout/LoD/offset mirror the other Build
  // methods in this dialect — confirm they are never consumed.
  pir::DataLayout layout = pir::DataLayout::NCHW;
  pir::LoD lod = {{0, 1, 2}};
  size_t offset = 0;

  argument.output_types.emplace_back(
      DenseTensorType::get(ctx, element_type, out_dims, layout, lod, offset));
}

// Element-wise op: the reified result shape is exactly the input's shape.
bool AbsOp::ReifyReturnTypeShapes(Builder &builder,
                                  const std::vector<OpOperand> &operands,
                                  std::vector<Value> &reified_return_shapes) {
  Value input = operands.front().source();
  return DeriveShapeFromOperand(&builder, input, &reified_return_shapes);
}

// Attribute slot names for TransposeOp ("perm" holds the permutation).
const char *TransposeOp::attributes_name[attributes_num] = {"perm"};

// Builds a TransposeOp: records `perm` as an ArrayAttribute of
// Int32Attributes and declares an output whose dims are the transposed
// input dims.
void TransposeOp::Build(Builder &builder,
                        OperationArgument &argument,
                        Value x,
                        std::vector<int> &perm) {
  std::vector<pir::Value> inputs = {x};
  argument.AddInputs(inputs);

  IrContext *ctx = pir::IrContext::Instance();

  // Pack the permutation into the "perm" attribute.
  std::vector<pir::Attribute> perm_attrs;
  perm_attrs.reserve(perm.size());
  for (int axis : perm) {
    perm_attrs.push_back(pir::Int32Attribute::get(ctx, axis));
  }
  argument.AddAttribute("perm", pir::ArrayAttribute::get(ctx, perm_attrs));

  // NOTE(review): the output dtype is IndexType here, whereas AbsOp::Build
  // preserves the operand's element type — confirm this is intentional for
  // the shape dialect.
  Type dtype = IndexType::get(ctx);
  pir::DDim in_dims = x.type().dyn_cast<DenseTensorType>().dims();
  pir::DDim out_dims = in_dims.transpose(perm);
  pir::DataLayout layout = pir::DataLayout::NCHW;
  pir::LoD lod = {{0, 1, 2}};
  size_t offset = 0;

  argument.output_types.emplace_back(
      DenseTensorType::get(ctx, dtype, out_dims, layout, lod, offset));
}

// Returns the permutation this transpose applies.
// TODO(zhangbopd): should not return just {1, 0} — read the real
// permutation from the "perm" attribute instead of assuming rank 2.
std::vector<int64_t> TransposeOp::permutation() {
  std::vector<int64_t> hard_coded_perm = {1, 0};
  return hard_coded_perm;
}

// Reifies the output shape of the transpose: for each input dimension idx,
// its extent is placed at the output position given by `permutation`.
// Returns false for unranked operands or malformed permutations.
bool TransposeOp::ReifyReturnTypeShapes(
    Builder &builder,
    const std::vector<OpOperand> &operands,
    std::vector<Value> &reified_return_shapes) {
  auto operand_type = operands[0].type().dyn_cast<DenseTensorType>();
  // Currently not support unranked type.
  if (!operand_type) return false;

  std::vector<int64_t> permutation = this->permutation();
  std::vector<Value> shape_values(permutation.size());

  Type shape_scalar_type = builder.index_type();

  auto to_shape_scalar_type = [&](Value v) {
    return MaybeCastTo(builder, v, shape_scalar_type);
  };

  auto shaped_type = operand_type.dyn_cast<ShapedTypeInterface>();
  auto shape_vector = shaped_type.GetDyShape();
  for (auto [idx, element] = std::tuple{0, shape_vector.begin()};
       element != shape_vector.end();
       ++idx, ++element) {
    auto it = std::find(permutation.begin(), permutation.end(), idx);
    // Guard: a dimension index missing from `permutation` would otherwise
    // make std::distance(begin, end) index one past the end of
    // `shape_values` (out-of-bounds write). This is reachable today because
    // permutation() is hard-coded to {1, 0} while the operand may have
    // rank > 2.
    if (it == permutation.end()) return false;
    // TODO(zhangbopd): Need BuildOrFold
    Value value_dim = to_shape_scalar_type(
        builder.Build<shape::TensorDimOp>(operands[0].source(), idx).result(0));
    shape_values[std::distance(permutation.begin(), it)] = value_dim;
  }

  Value output_shape =
      builder.Build<shape::FromElementsOp>(shape_values).result(0);
  reified_return_shapes.push_back(output_shape);

  return true;
}

// Builds a ConcatOp from the tensor `x` and the concatenation `axis`.
// NOTE(review): unlike the other Build methods in this file, no output type
// is recorded here — confirm whether callers rely on a declared result.
void ConcatOp::Build(Builder &builder,
                     OperationArgument &argument,
                     Value x,
                     Value axis) {
  argument.AddInput(x);
  argument.AddInput(axis);
}

// Reifies the concat result shape: materializes every dimension of every
// input as an index-typed Value, then (eventually) sums the extents along
// the concat axis. The axis accumulation is still a TODO, so the first
// input's shape is returned as-is.
bool ConcatOp::ReifyReturnTypeShapes(
    Builder &builder,
    const std::vector<OpOperand> &operands,
    std::vector<Value> &reified_return_shapes) {
  std::vector<Value> inputs = {x()};

  // Currently not support unranked type.
  if (!inputs[0].type().dyn_cast<DenseTensorType>()) return false;

  Type shape_scalar_type = builder.index_type();
  auto to_shape_scalar_type = [&](Value v) {
    return MaybeCastTo(builder, v, shape_scalar_type);
  };

  // One vector of dimension Values per input.
  std::vector<std::vector<Value>> all_shape_values;
  for (Value operand : inputs) {
    auto ranked_type = operand.type().dyn_cast<DenseTensorType>();
    if (!ranked_type) return false;

    std::vector<Value> dim_values;
    auto shaped = ranked_type.dyn_cast<ShapedTypeInterface>();
    auto dy_shape = shaped.GetDyShape();
    int idx = 0;
    for (auto element = dy_shape.begin(); element != dy_shape.end();
         ++element, ++idx) {
      dim_values.push_back(to_shape_scalar_type(
          builder.Build<shape::TensorDimOp>(operand, idx).result(0)));
    }
    all_shape_values.emplace_back(std::move(dim_values));
  }

  [[maybe_unused]] int axis = this->dimension();
  auto &shape_values = all_shape_values[0];
  for (size_t i = 1; i < all_shape_values.size(); ++i) {
    auto &other_shape_values = all_shape_values[i];
    // Rank mismatch between inputs: cannot concat.
    if (other_shape_values.size() != shape_values.size()) return false;
    // TODO(zhangbopd): AddIOp
    // shape_values[axis] =
    //     builder.Build<arith::AddIOp>(shape_values[axis],
    //                                  other_shape_values[axis]);
  }

  Value output_shape =
      builder.Build<shape::FromElementsOp>(shape_values).result(0);
  reified_return_shapes.push_back(output_shape);
  return true;
}

} // namespace pir::shape

IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::AbsOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::TransposeOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ConcatOp)
Loading