Commit 0ded8f3

reply review comments

xiaoguoguo626807 committed Aug 25, 2023
1 parent 71d1c09 commit 0ded8f3

Showing 2 changed files with 124 additions and 3 deletions.
122 changes: 119 additions & 3 deletions paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
@@ -173,6 +173,69 @@ OpInfoTuple SplitGradOp::GetOpInfo() {
      inputs, attributes, outputs, run_time_info, "split_grad");
}

void SplitGradOp::Build(ir::Builder &builder,
                        ir::OperationArgument &argument,
                        ir::OpResult out_grad_,
                        float axis) {
  // Generate scalar mutable attribute: axis
  paddle::dialect::FullOp full_axis_op = builder.Build<paddle::dialect::FullOp>(
      std::vector<int64_t>{1}, axis, phi::DataType::FLOAT32, phi::CPUPlace());
  ir::OpResult axis_ = full_axis_op->result(0);

  VLOG(4) << "Builder construction inputs";
  std::vector<ir::OpResult> argument_inputs = {out_grad_, axis_};
  argument.AddOperands(argument_inputs.begin(), argument_inputs.end());

  VLOG(4) << "Builder construction attributes";

  VLOG(4) << "Builder construction outputs";
  ir::VectorType out_grad = out_grad_.type().dyn_cast<ir::VectorType>();
  // Materialize a placeholder DenseTensor per out_grad slice so that
  // ConcatInferMeta can run at build time.
  std::vector<phi::DenseTensor> vec_dense_out_grad;
  for (size_t i = 0; i < static_cast<size_t>(out_grad.size()); i++) {
    vec_dense_out_grad.push_back(phi::DenseTensor(
        std::make_unique<paddle::experimental::DefaultAllocator>(
            paddle::platform::CPUPlace())
            .get(),
        phi::DenseTensorMeta(
            paddle::dialect::TransToPhiDataType(
                out_grad[i]
                    .dyn_cast<paddle::dialect::DenseTensorType>()
                    .dtype()),
            out_grad[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
            out_grad[i]
                .dyn_cast<paddle::dialect::DenseTensorType>()
                .data_layout(),
            out_grad[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
            out_grad[i]
                .dyn_cast<paddle::dialect::DenseTensorType>()
                .offset())));
  }
  std::vector<phi::MetaTensor> vec_meta_out_grad;
  for (size_t i = 0; i < vec_dense_out_grad.size(); i++) {
    vec_meta_out_grad.push_back(phi::MetaTensor(&vec_dense_out_grad[i]));
  }

  std::vector<const phi::MetaTensor *> meta_out_grad;
  for (size_t i = 0; i < static_cast<size_t>(vec_meta_out_grad.size()); i++) {
    meta_out_grad.push_back(&vec_meta_out_grad[i]);
  }
  phi::DenseTensor dense_x_grad;
  phi::MetaTensor meta_x_grad(&dense_x_grad);

  // x_grad's meta is the concat of the out_grad slices along axis.
  phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_grad);

  // Convert the inferred meta back into an IR type for the single output.
  std::vector<ir::Type> argument_outputs;
  ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get(
      ir::IrContext::Instance(),
      paddle::dialect::TransToIrDataType(dense_x_grad.dtype()),
      dense_x_grad.dims(),
      dense_x_grad.layout(),
      dense_x_grad.lod(),
      dense_x_grad.offset());
  argument_outputs.push_back(x_grad_dense_tensor_type);
  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}
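
A minimal calling sketch for this new float-axis overload, for orientation. `ctx`, `program`, and `grads` are assumed setup names, not part of this commit; only Build and the x_grad() accessor come from the diff:

// Sketch only: assumes a program under construction and a `grads` OpResult
// carrying an ir::VectorType of DenseTensorType slices.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx);
ir::Builder builder(ctx, program.block());
// The overload above synthesizes the axis FullOp internally, so the caller
// passes only the combined gradients and a scalar axis.
auto split_grad_op =
    builder.Build<paddle::dialect::SplitGradOp>(grads, /*axis=*/0.0f);
ir::OpResult x_grad = split_grad_op.x_grad();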

void SplitGradOp::Build(ir::Builder &builder,
                        ir::OperationArgument &argument,
                        ir::OpResult out_grad_,
@@ -185,15 +248,13 @@ void SplitGradOp::Build(ir::Builder &builder,

VLOG(4) << "Builder construction outputs";
ir::VectorType out_grad = out_grad_.type().dyn_cast<ir::VectorType>();
(void)out_grad;
int axis = axis_.owner()
->dyn_cast<paddle::dialect::FullOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int>();
(void)axis;

std::vector<phi::DenseTensor> vec_dense_out_grad;
for (size_t i = 0; i < static_cast<size_t>(out_grad.size()); i++) {
Expand Down Expand Up @@ -240,7 +301,62 @@ void SplitGradOp::Build(ir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}
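
By contrast, the pre-existing overload above takes axis as an ir::OpResult and expects it to be produced by a full op, since Build reads the scalar back out of that op's "value" attribute. A hedged calling sketch, reusing the assumed `builder` and `grads` from the previous example:

// Sketch only: materialize axis as a 1-D full op first; Build then recovers
// the scalar via axis_.owner()->dyn_cast<paddle::dialect::FullOp>().
auto full_axis_op = builder.Build<paddle::dialect::FullOp>(
    std::vector<int64_t>{1}, 0.0f, phi::DataType::FLOAT32, phi::CPUPlace());
builder.Build<paddle::dialect::SplitGradOp>(grads, full_axis_op->result(0));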

void SplitGradOp::Verify() {}
void SplitGradOp::Verify() {
  VLOG(4) << "Start Verifying inputs, outputs and attributes for: SplitGradOp.";
  VLOG(4) << "Verifying inputs:";
  {
    auto input_size = num_operands();
    PADDLE_ENFORCE_EQ(
        input_size,
        2u,
        phi::errors::PreconditionNotMet(
            "The size %d of inputs must be equal to 2.", input_size));
    if (auto vec_type =
            (*this)->operand_source(0).type().dyn_cast<ir::VectorType>()) {
      for (size_t i = 0; i < vec_type.size(); ++i) {
        PADDLE_ENFORCE(vec_type[i].isa<paddle::dialect::DenseTensorType>(),
                       phi::errors::PreconditionNotMet(
                           "Type validation failed for the 0th input."));
      }
    } else {
      PADDLE_ENFORCE((*this)
                         ->operand_source(0)
                         .type()
                         .isa<paddle::dialect::DenseTensorType>(),
                     phi::errors::PreconditionNotMet(
                         "Type validation failed for the 0th input."));
    }
    PADDLE_ENFORCE((*this)
                       ->operand_source(1)
                       .type()
                       .isa<paddle::dialect::DenseTensorType>(),
                   phi::errors::PreconditionNotMet(
                       "Type validation failed for the 1st input."));
  }
  VLOG(4) << "Verifying attributes:";
  {
    // The op has no attributes, so there is no attribute type to check.
  }
  VLOG(4) << "Verifying outputs:";
  {
    auto output_size = num_results();
    PADDLE_ENFORCE_EQ(
        output_size,
        1u,
        phi::errors::PreconditionNotMet(
            "The size %d of outputs must be equal to 1.", output_size));
    PADDLE_ENFORCE(
        (*this)->result(0).type().isa<paddle::dialect::DenseTensorType>(),
        phi::errors::PreconditionNotMet(
            "Type validation failed for the 0th output."));
  }
  VLOG(4) << "End Verifying for: SplitGradOp.";
}

void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) {
  auto fn = PD_INFER_META(phi::ConcatInferMeta);
  fn(infer_meta);
}

} // namespace dialect
} // namespace paddle
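
Because split_grad is mathematically the concat of the sliced gradients along axis, both Build and InferMeta above delegate shape inference to phi::ConcatInferMeta. The shape rule being applied, as a self-contained sketch (ConcatDims is a hypothetical helper, not a Paddle API):

#include <cstdint>
#include <vector>

// Output dims of concatenating `ins` along `axis`: identical to each input
// except that the extents along `axis` are summed.
std::vector<int64_t> ConcatDims(const std::vector<std::vector<int64_t>> &ins,
                                size_t axis) {
  std::vector<int64_t> out = ins.front();
  for (size_t i = 1; i < ins.size(); ++i) {
    out[axis] += ins[i][axis];
  }
  return out;
}
// e.g. split pieces of shape {2, 3} and {2, 5} concatenated along axis 1
// recover an x_grad of shape {2, 8}.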
5 changes: 5 additions & 0 deletions paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h
@@ -58,6 +58,10 @@ class SplitGradOp : public ir::Op<SplitGradOp, OpYamlInfoInterface> {
  static const char *attributes_name[1];
  static constexpr uint32_t attributes_num = 1;
  static OpInfoTuple GetOpInfo();
  static void Build(ir::Builder &builder,             // NOLINT
                    ir::OperationArgument &argument,  // NOLINT
                    ir::OpResult x_,
                    float axis = 0);
  static void Build(ir::Builder &builder,             // NOLINT
                    ir::OperationArgument &argument,  // NOLINT
                    ir::OpResult out_grad_,
@@ -67,6 +71,7 @@ class SplitGradOp : public ir::Op<SplitGradOp, OpYamlInfoInterface> {
  ir::Value out_grad() { return operand_source(0); }
  ir::Value axis() { return operand_source(1); }
  ir::OpResult x_grad() { return result(0); }
  static void InferMeta(phi::InferMetaContext *infer_meta);
};

} // namespace dialect
