[Phi] Migrate InferShape of multiplex, qr, tril_triu (#40102)
* migrate infershape

* fix tril_triu infershape error

* fix qr_op infershape

* add parse qr mode func

* move order
Caozhou1995 authored Mar 24, 2022
1 parent f51a579 commit 2e73653
Showing 9 changed files with 179 additions and 117 deletions.
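All three operators below follow the same migration pattern: the `InferShape` override is deleted from the fluid operator class, an equivalent `InferMeta` function is added under `paddle/phi/infermeta/`, and the two are wired together at registration time. Here is a minimal sketch of that wiring, using a hypothetical operator `foo` and a hypothetical `phi::FooInferMeta`; the macros themselves are the ones that appear in the diffs below.

```cpp
// Sketch only: foo, FooOp, FooOpMaker, and phi::FooInferMeta are hypothetical
// placeholders standing in for multiplex/qr/tril_triu in this commit.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"

namespace ops = paddle::operators;

// Before: FooOp overrode InferShape(framework::InferShapeContext*) directly.
// After: a functor generated from the phi InferMeta function is registered
// alongside the operator, and the in-class override is removed.
DECLARE_INFER_SHAPE_FUNCTOR(foo, FooInferShapeFunctor,
                            PD_INFER_META(phi::FooInferMeta));
REGISTER_OPERATOR(foo, ops::FooOp, ops::FooOpMaker,
                  FooInferShapeFunctor);
```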
48 changes: 9 additions & 39 deletions paddle/fluid/operators/multiplex_op.cc
@@ -14,8 +14,13 @@ limitations under the License. */

#include <memory>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"

#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {

@@ -25,44 +30,6 @@ class MultiplexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "Multiplex");
PADDLE_ENFORCE_NE(
ctx->Inputs("X").empty(), true,
platform::errors::InvalidArgument("MultiInput(X) shouldn't be empty."));
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multiplex");
auto ids_dim = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(
ids_dim.size(), 2,
platform::errors::PreconditionNotMet(
"The index tensor must be a vector with 2 dimensions"));
PADDLE_ENFORCE_EQ(
ids_dim[1], 1,
platform::errors::PreconditionNotMet(
"The index tensor must be a vector with batchSize x 1."));

auto ins_dims = ctx->GetInputsDim("X");
auto num_ins = ins_dims.size();
PADDLE_ENFORCE_GT(num_ins, 1,
platform::errors::InvalidArgument(
"multiplex operator should have more than "
"one candidate input tensors."));

auto in_dim = ins_dims[0];
PADDLE_ENFORCE_GE(
in_dim.size(), 2,
platform::errors::InvalidArgument(
"The rank of candidate tensors must be not less than 2."));
for (size_t i = 1; i < num_ins; i++) {
auto dim = ins_dims[i];
PADDLE_ENFORCE_EQ(
in_dim, dim,
platform::errors::PreconditionNotMet(
"All the candidate tensors must have the same size."));
}
ctx->SetOutputDim("Out", in_dim);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -164,8 +131,11 @@ class MultiplexGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(multiplex, MultiplexInferShapeFunctor,
PD_INFER_META(phi::MultiplexInferMeta));

REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
ops::MultiplexGradMaker<paddle::framework::OpDesc>,
ops::MultiplexGradMaker<paddle::imperative::OpBase>);
ops::MultiplexGradMaker<paddle::imperative::OpBase>,
MultiplexInferShapeFunctor);
REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
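For reference, the shape rules that the removed `MultiplexOp::InferShape` enforced (and that `phi::MultiplexInferMeta`, added later in this commit, re-implements) can be restated in a standalone form. The helper below is a hypothetical illustration, not Paddle code.

```cpp
#include <cassert>
#include <vector>

// Illustrative only: restates the constraints checked by the multiplex shape
// function. MultiplexOutputShape is a hypothetical name, not part of Paddle.
std::vector<int> MultiplexOutputShape(
    const std::vector<std::vector<int>>& candidate_dims,  // dims of each X
    const std::vector<int>& ids_dims) {                    // dims of Ids
  assert(ids_dims.size() == 2 && ids_dims[1] == 1);  // Ids must be [batch, 1]
  assert(candidate_dims.size() > 1);                 // more than one candidate
  const auto& first = candidate_dims[0];
  assert(first.size() >= 2);                         // candidate rank >= 2
  for (const auto& d : candidate_dims) {
    assert(d == first);                              // all candidates match
  }
  return first;                                      // Out copies X's shape
}

int main() {
  // Two candidate tensors of shape [8, 16] selected by Ids of shape [8, 1]:
  auto out = MultiplexOutputShape({{8, 16}, {8, 16}}, {8, 1});
  assert(out == (std::vector<int>{8, 16}));
  return 0;
}
```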
48 changes: 8 additions & 40 deletions paddle/fluid/operators/qr_op.cc
@@ -21,6 +21,10 @@
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
@@ -29,43 +33,6 @@ using DDim = framework::DDim;
class QrOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "qr");
OP_INOUT_CHECK(ctx->HasOutput("Q"), "Output", "Q", "qr");
OP_INOUT_CHECK(ctx->HasOutput("R"), "Output", "R", "qr");

auto x_dims = ctx->GetInputDim("X");
int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"the rank of input must greater than 2"));
bool compute_q;
bool reduced_mode;
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
std::string mode = ctx->Attrs().Get<std::string>("mode");
std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode);

if (compute_q) {
int k = reduced_mode ? min_mn : m;
auto q_dims_vec = phi::vectorize(x_dims);
q_dims_vec[q_dims_vec.size() - 1] = k;
ctx->SetOutputDim("Q", phi::make_ddim(q_dims_vec));
} else {
ctx->SetOutputDim("Q", phi::make_ddim({0}));
}

int k = reduced_mode ? min_mn : m;
auto r_dims_vec = phi::vectorize(x_dims);
r_dims_vec[r_dims_vec.size() - 2] = k;
r_dims_vec[r_dims_vec.size() - 1] = n;
ctx->SetOutputDim("R", phi::make_ddim(r_dims_vec));

ctx->ShareLoD("X", /*->*/ "Q");
ctx->ShareLoD("X", /*->*/ "R");
}
};

class QrOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -83,10 +50,8 @@ class QrOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault("reduced");
AddComment(R"DOC(
Qr Operator.
This operator is used to perform QR operation for batched matrics $X$.
$$Q, R = qr(X)$$
)DOC");
}
};
@@ -138,10 +103,13 @@ class QrGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(qr, QrInferShapeFunctor,
PD_INFER_META(phi::QrInferMeta));

REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker,
ops::QrGradMaker<paddle::framework::OpDesc>,
ops::QrGradMaker<paddle::imperative::OpBase>);
ops::QrGradMaker<paddle::imperative::OpBase>,
QrInferShapeFunctor);

REGISTER_OPERATOR(qr_grad, ops::QrGradOp);

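The shape arithmetic removed from `QrOp::InferShape` above reappears unchanged in `phi::QrInferMeta` (see `paddle/phi/infermeta/unary.cc` below). A quick worked example of that arithmetic; the helper is a standalone, hypothetical illustration, not Paddle code.

```cpp
#include <algorithm>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Illustrative only: mirrors the Q/R shape computation used by the QR shape
// function. QrOutputShapes is a hypothetical name, not part of Paddle.
std::pair<std::vector<int>, std::vector<int>> QrOutputShapes(
    std::vector<int> x_dims, const std::string& mode) {
  const int rank = static_cast<int>(x_dims.size());
  const int m = x_dims[rank - 2];
  const int n = x_dims[rank - 1];
  const int min_mn = std::min(m, n);
  const bool compute_q = (mode != "r");        // "reduced" or "complete"
  const bool reduced = (mode != "complete");   // "reduced" or "r"
  const int k = reduced ? min_mn : m;

  std::vector<int> q_dims{0};                  // Q is unused when mode == "r"
  if (compute_q) {
    q_dims = x_dims;
    q_dims[rank - 1] = k;                      // Q: [..., m, k]
  }
  std::vector<int> r_dims = x_dims;
  r_dims[rank - 2] = k;                        // R: [..., k, n]
  return {q_dims, r_dims};
}

int main() {
  // For a batched input of shape [4, 5, 3] (m = 5, n = 3, min_mn = 3):
  //   "reduced"  -> Q: [4, 5, 3], R: [4, 3, 3]
  //   "complete" -> Q: [4, 5, 5], R: [4, 5, 3]
  //   "r"        -> Q: [0],       R: [4, 3, 3]
  auto shapes = QrOutputShapes({4, 5, 3}, "reduced");
  std::printf("Q rank %zu, R rank %zu\n", shapes.first.size(),
              shapes.second.size());
  return 0;
}
```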
24 changes: 8 additions & 16 deletions paddle/fluid/operators/tril_triu_op.cc
@@ -13,29 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"

#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

class TrilTriuOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of TrilTriuOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of TrilTriuOp is not found."));
const auto& x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X)'s rank must be at least 2 in TrilTriuOp."));
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};

class TrilTriuOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -100,7 +89,10 @@ class TrilTriuGradOpMaker : public framework::SingleGradOpMaker<T> {

namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(tril_triu, TrilTriuInferShapeFunctor,
PD_INFER_META(phi::TrilTriuInferMeta));
REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>,
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>,
TrilTriuInferShapeFunctor);
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
44 changes: 44 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -832,6 +832,50 @@ void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
out->share_lod(*x.at(0));
}

void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out) {
PADDLE_ENFORCE_NE(
ins.empty(),
true,
phi::errors::InvalidArgument("MultiInput(X) shouldn't be empty."));
auto ids_dim = ids.dims();
PADDLE_ENFORCE_EQ(ids_dim.size(),
2,
phi::errors::PreconditionNotMet(
"The index tensor must be a vector with 2 dimensions"));
PADDLE_ENFORCE_EQ(
ids_dim[1],
1,
phi::errors::PreconditionNotMet(
"The index tensor must be a vector with batchSize x 1."));

auto ins_dims = GetMetaTensorsDim(ins);
auto num_ins = ins_dims.size();
PADDLE_ENFORCE_GT(
num_ins,
1,
phi::errors::InvalidArgument("multiplex operator should have more than "
"one candidate input tensors."));

auto in_dim = ins_dims[0];
PADDLE_ENFORCE_GE(
in_dim.size(),
2,
phi::errors::InvalidArgument(
"The rank of candidate tensors must be not less than 2."));
for (size_t i = 1; i < num_ins; i++) {
auto dim = ins_dims[i];
PADDLE_ENFORCE_EQ(
in_dim,
dim,
phi::errors::PreconditionNotMet(
"All the candidate tensors must have the same size."));
}
out->set_dims(in_dim);
out->set_dtype(ins[0]->dtype());
}

void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num,
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -152,6 +152,10 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x,

void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);

void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out);

void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num,
53 changes: 53 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
@@ -1129,6 +1130,44 @@ void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_layout(x.layout());
}

void QrInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* q,
MetaTensor* r) {
auto x_dims = x.dims();
int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(
x_dims.size(),
2,
phi::errors::InvalidArgument("the rank of input must greater than 2"));
bool compute_q;
bool reduced_mode;
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode);

if (compute_q) {
int k = reduced_mode ? min_mn : m;
auto q_dims_vec = phi::vectorize(x_dims);
q_dims_vec[q_dims_vec.size() - 1] = k;
q->set_dims(phi::make_ddim(q_dims_vec));
} else {
q->set_dims(phi::make_ddim({0}));
}

int k = reduced_mode ? min_mn : m;
auto r_dims_vec = phi::vectorize(x_dims);
r_dims_vec[r_dims_vec.size() - 2] = k;
r_dims_vec[r_dims_vec.size() - 1] = n;
r->set_dims(phi::make_ddim(r_dims_vec));

q->share_lod(x);
r->share_lod(x);
q->set_dtype(x.dtype());
r->set_dtype(x.dtype());
}

DDim ReduceInferDim(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
@@ -1847,6 +1886,20 @@ void UnbindInferMeta(const MetaTensor& x,
}
}

void TrilTriuInferMeta(const MetaTensor& x,
int diagonal,
bool lower,
MetaTensor* out) {
const auto& x_dims = x.dims();
PADDLE_ENFORCE_GE(x_dims.size(),
2,
phi::errors::InvalidArgument(
"Input(X)'s rank must be at least 2 in TrilTriuOp."));
out->set_dims(x.dims());
out->share_lod(x);
out->set_dtype(x.dtype());
}

void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) {
out->share_meta(x);
}
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -180,6 +180,11 @@ void PoolInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void QrInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* q,
MetaTensor* r);

void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out);

void ReduceInferMeta(const MetaTensor& x,
@@ -282,6 +287,11 @@ void TransposeGradInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);

void TrilTriuInferMeta(const MetaTensor& x,
int diagonal,
bool lower,
MetaTensor* out);

void UnbindInferMeta(const MetaTensor& x,
int axis,
std::vector<MetaTensor>* outs);
24 changes: 2 additions & 22 deletions paddle/phi/kernels/cpu/qr_kernel.cc
@@ -19,30 +19,10 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"

namespace phi {

static inline std::tuple<bool, bool> ParseQrMode(const std::string& mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
compute_q = true;
reduced = true;
} else if (mode == "complete") {
compute_q = true;
reduced = false;
} else if (mode == "r") {
compute_q = false;
reduced = true;
} else {
PADDLE_THROW(errors::InvalidArgument(
"QR received unrecognized mode '%s'"
" but expected one of 'reduced' (default), 'r', or 'complete'",
mode));
}
return std::make_tuple(compute_q, reduced);
}

template <typename T, typename Context>
void QrKernel(const Context& ctx,
const DenseTensor& x,
@@ -51,7 +31,7 @@ void QrKernel(const Context& ctx,
DenseTensor* r) {
bool compute_q;
bool reduced_mode;
std::tie(compute_q, reduced_mode) = ParseQrMode(mode);
std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode);
auto numel = x.numel();
PADDLE_ENFORCE_GT(
numel, 0, errors::PreconditionNotMet("The input of QR is empty."));
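The new shared header `paddle/phi/kernels/funcs/parse_qr_mode.h` (mentioned in the commit message and included by both `unary.cc` and `qr_kernel.cc`) is not expanded in this view. Based on the local `ParseQrMode` removed from `qr_kernel.cc` above, it presumably looks roughly like the sketch below — a reconstruction, not the actual file contents.

```cpp
// Presumed contents of paddle/phi/kernels/funcs/parse_qr_mode.h, reconstructed
// from the ParseQrMode implementation removed from qr_kernel.cc above.
#pragma once

#include <string>
#include <tuple>

#include "paddle/phi/core/enforce.h"

namespace phi {
namespace funcs {

// Returns {compute_q, reduced} for a QR mode string.
static inline std::tuple<bool, bool> ParseQrMode(const std::string& mode) {
  bool compute_q;
  bool reduced;
  if (mode == "reduced") {
    compute_q = true;
    reduced = true;
  } else if (mode == "complete") {
    compute_q = true;
    reduced = false;
  } else if (mode == "r") {
    compute_q = false;
    reduced = true;
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "QR received unrecognized mode '%s'"
        " but expected one of 'reduced' (default), 'r', or 'complete'",
        mode));
  }
  return std::make_tuple(compute_q, reduced);
}

}  // namespace funcs
}  // namespace phi
```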