Skip to content

Commit

Permalink
move eye, lerp infershape to phi (#40105)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored Mar 3, 2022
1 parent 167d511 commit 1c20588
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 65 deletions.
26 changes: 7 additions & 19 deletions paddle/fluid/operators/eye_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h"

namespace paddle {
namespace operators {
Expand All @@ -21,24 +24,6 @@ class EyeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of EyeOP should not be null."));
auto num_rows = ctx->Attrs().Get<int64_t>("num_rows");
PADDLE_ENFORCE_EQ(
num_rows >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_rows) should be non-negative int."));
auto num_columns = ctx->Attrs().Get<int64_t>("num_columns");
if (num_columns == -1) num_columns = num_rows;
PADDLE_ENFORCE_EQ(
num_columns >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_columns) should be non-negative int."));
ctx->SetOutputDim("Out", {num_rows, num_columns});
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand Down Expand Up @@ -82,8 +67,11 @@ Return an identity tensor whose shape is [num_rows, num_columns].
} // namespace paddle

namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(eye, EyeInferShapeFunctor,
PT_INFER_META(phi::EyeInferMeta));

REGISTER_OPERATOR(
eye, ops::EyeOp, ops::EyeOpMaker, ops::EyeOpVarTypeInference,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
EyeInferShapeFunctor);
50 changes: 6 additions & 44 deletions paddle/fluid/operators/lerp_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,57 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"

namespace paddle {
namespace operators {

class LerpOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lerp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "lerp");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "lerp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "lerp");

auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto w_dims = ctx->GetInputDim("Weight");
framework::DDim out_dims;
out_dims = GetOutputDims(x_dims, y_dims);
if (w_dims.size() > 1 || w_dims[0] != 1) {
out_dims = GetOutputDims(out_dims, w_dims);
}

ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}

private:
framework::DDim GetOutputDims(const framework::DDim& s_dims,
const framework::DDim& l_dims) const {
if (s_dims.size() > l_dims.size()) {
return GetOutputDims(l_dims, s_dims);
}
std::vector<int64_t> shapes = phi::vectorize<int64_t>(l_dims);
for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
int64_t s = s_dims[i];
int64_t l = l_dims[j];
if (s != l) {
if (l == 1) {
shapes[j] = s;
} else if (s != 1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The shape of tensor a %s:%d must match shape of tensor b "
"%s:%d.",
s_dims.to_str(), i, l_dims.to_str(), j));
}
}
}
return phi::make_ddim(shapes);
}
};

class LerpOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down Expand Up @@ -125,10 +85,12 @@ DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle

DELCARE_INFER_SHAPE_FUNCTOR(lerp, LerpInferShapeFunctor,
PT_INFER_META(phi::LerpInferMeta));
REGISTER_OPERATOR(
lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker,
paddle::operators::LerpOpGradMaker<paddle::framework::OpDesc>,
paddle::operators::LerpOpGradMaker<paddle::imperative::OpBase>,
paddle::operators::LerpInplaceInferer);
paddle::operators::LerpInplaceInferer, LerpInferShapeFunctor);

REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/nullary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,12 @@ void CreateInferMeta(const ScalarArray& shape,
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
}

void EyeInferMeta(int64_t num_rows,
int64_t num_columns,
DataType dtype,
MetaTensor* out) {
if (num_columns == -1) num_columns = num_rows;
out->set_dims({num_rows, num_columns});
out->set_dtype(dtype);
}
} // namespace phi
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/nullary.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,9 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape,

void CreateInferMeta(const ScalarArray& shape, DataType dtype, MetaTensor* out);

void EyeInferMeta(int64_t num_rows,
int64_t num_columns,
DataType dtype,
MetaTensor* out);

} // namespace phi
17 changes: 17 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,21 @@ void AddmmInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}

void LerpInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto w_dims = weight.dims();
DDim out_dims;
out_dims = funcs::GetOutputDims(x_dims, y_dims);
if (w_dims.size() > 1 || w_dims[0] != 1) {
out_dims = funcs::GetOutputDims(out_dims, w_dims);
}
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}

} // namespace phi
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,9 @@ void AddmmInferMeta(const MetaTensor& input,
float beta,
MetaTensor* out);

void LerpInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
MetaTensor* out);

} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/eye_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ template <typename T, typename Context>
void EyeKernel(const Context& ctx,
int64_t num_rows,
int64_t num_columns,
int dtype,
DataType dtype,
DenseTensor* out);

} // namespace phi
25 changes: 25 additions & 0 deletions paddle/phi/kernels/funcs/common_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,30 @@ inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) {
return true;
}

inline DDim GetOutputDims(const DDim &s_dims, const DDim &l_dims) {
if (s_dims.size() > l_dims.size()) {
return GetOutputDims(l_dims, s_dims);
}
std::vector<int64_t> shapes = phi::vectorize<int64_t>(l_dims);
for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
int64_t s = s_dims[i];
int64_t l = l_dims[j];
if (s != l) {
if (l == 1) {
shapes[j] = s;
} else if (s != 1) {
PADDLE_THROW(errors::InvalidArgument(
"The shape of tensor a %s:%d must match shape of tensor b "
"%s:%d.",
s_dims.to_str(),
i,
l_dims.to_str(),
j));
}
}
}
return phi::make_ddim(shapes);
}

} // namespace funcs
} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/eye_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ template <typename T, typename Context>
void EyeKernel(const Context& ctx,
int64_t num_rows,
int64_t num_columns,
int dtype,
DataType dtype,
DenseTensor* out) {
auto num = num_columns;
if (num == -1) {
Expand Down

0 comments on commit 1c20588

Please sign in to comment.