Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI] Remove infershape of set_value op #40636

Merged
merged 2 commits into from
Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions paddle/fluid/framework/infershape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,51 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
attr_name, infershape_input.size()));
}
}
} else if (attr_defs[i].type_index ==
           std::type_index(typeid(std::vector<phi::Scalar>))) {
  auto& attr = attr_reader.GetAttr(attr_name);
  // Convert a homogeneous numeric vector attribute into the
  // vector<phi::Scalar> the InferMeta function expects. Factored into a
  // generic lambda so the conversion is written once instead of being
  // duplicated for each supported element type (int32/int64/float/double).
  auto to_scalar_list = [](const auto& vec) {
    std::vector<phi::Scalar> scalar_list;
    scalar_list.reserve(vec.size());
    for (const auto& val : vec) {
      scalar_list.emplace_back(val);
    }
    return scalar_list;
  };
  if (std::type_index(attr.type()) ==
      std::type_index(typeid(std::vector<int32_t>))) {
    infer_meta_context.EmplaceBackAttr(
        to_scalar_list(BOOST_GET_CONST(std::vector<int32_t>, attr)));
  } else if (std::type_index(attr.type()) ==
             std::type_index(typeid(std::vector<int64_t>))) {
    infer_meta_context.EmplaceBackAttr(
        to_scalar_list(BOOST_GET_CONST(std::vector<int64_t>, attr)));
  } else if (std::type_index(attr.type()) ==
             std::type_index(typeid(std::vector<float>))) {
    infer_meta_context.EmplaceBackAttr(
        to_scalar_list(BOOST_GET_CONST(std::vector<float>, attr)));
  } else if (std::type_index(attr.type()) ==
             std::type_index(typeid(std::vector<double>))) {
    infer_meta_context.EmplaceBackAttr(
        to_scalar_list(BOOST_GET_CONST(std::vector<double>, attr)));
  } else {
    // Unknown element type for this attribute: fail loudly rather than
    // silently skipping the attr.
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported cast op attribute `%s` to vector<Scalar> when "
        "construct InferMetaContext.",
        attr_names[i]));
  }
} else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name);
Expand Down
24 changes: 12 additions & 12 deletions paddle/fluid/operators/set_value_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@
// limitations under the License.

#include "paddle/fluid/operators/set_value_op.h"

#include <string>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"

#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace framework {
class InferShapeContext;
Expand All @@ -34,24 +40,15 @@ class CPUDeviceContext;
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class SetValue : public framework::OperatorWithKernel {
public:
SetValue(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "SetValue");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SetValue");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(
in_dims.size(), 7,
platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.",
in_dims.size()));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
Expand Down Expand Up @@ -236,10 +233,13 @@ DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"});
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Bind phi::SetValueInferMeta as the InferShape implementation for
// set_value; this replaces the legacy SetValue::InferShape method that
// was removed from the operator class above.
DECLARE_INFER_SHAPE_FUNCTOR(set_value, SetValueInferShapeFunctor,
PD_INFER_META(phi::SetValueInferMeta));

// Register set_value with its maker, grad makers (static + dygraph),
// the Input->Out inplace inferer, and the phi-backed shape functor.
REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
ops::SetValueGradMaker<paddle::framework::OpDesc>,
ops::SetValueGradMaker<paddle::imperative::OpBase>,
ops::SetValueOpInplaceInferer, SetValueInferShapeFunctor);

// The grad op keeps its own (non-phi) registration.
REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad);

Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,16 @@ void RollInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个infermeta的参数和kernel参数也不一致吧,后续自动生成到generator.h/cc中的infermeta会和这个函数重名吗

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要补充下对out dims和dtype的推断吗?虽然原来的没写,但这里应该有?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是先绕过了给算子注册两个InferMeta函数的问题,参数不一致的问题可以在自动生成的时候处理,函数名重名的问题到时候一并解决。
set_value算子是inplace的,正常的话InferMeta里确实也不需要设置,为了稳妥起见迁过来也没有设置dim和dtype,后面有时间再把这里补上测试下

auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
in_dims.size(),
7,
phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.",
in_dims.size()));
}

void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) {
auto in_dim = input.dims();
out->set_dims(phi::make_ddim({in_dim.size()}));
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ void RollInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
MetaTensor* out);

void SetValueInferMeta(const MetaTensor& x, MetaTensor* out);

void ShapeInferMeta(const MetaTensor& input, MetaTensor* out);

void ShardIndexInferMeta(const MetaTensor& in,
Expand Down
28 changes: 14 additions & 14 deletions paddle/phi/ops/compat/set_value_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ namespace phi {

KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("Input")) {
if (ctx.HasInput("StartsTensorList")) {
if (ctx.HasInput("EndsTensorList")) {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StartsTensorList") > 0) {
if (ctx.InputSize("EndsTensorList") > 0) {
if (ctx.InputSize("StepsTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("set_value_with_tensor",
{"Input", "ValueTensor"},
Expand Down Expand Up @@ -197,7 +197,7 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
}
} else {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StepsTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("set_value_with_tensor",
{"Input", "ValueTensor"},
Expand Down Expand Up @@ -374,8 +374,8 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
}
} else {
if (ctx.HasInput("EndsTensorList")) {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("EndsTensorList") > 0) {
if (ctx.InputSize("StepsTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("set_value_with_tensor",
{"Input", "ValueTensor"},
Expand Down Expand Up @@ -551,7 +551,7 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
}
} else {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StepsTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("set_value_with_tensor",
{"Input", "ValueTensor"},
Expand Down Expand Up @@ -734,9 +734,9 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {

KernelSignature SetValueGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("StartsTensorList")) {
if (ctx.HasInput("EndsTensorList")) {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StartsTensorList") > 0) {
if (ctx.InputSize("EndsTensorList") > 0) {
if (ctx.InputSize("StepsTensorList") > 0) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
Expand All @@ -760,7 +760,7 @@ KernelSignature SetValueGradOpArgumentMapping(
{GradVarName("Input"), GradVarName("ValueTensor")});
}
} else {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StepsTensorList") > 0) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
Expand All @@ -785,8 +785,8 @@ KernelSignature SetValueGradOpArgumentMapping(
}
}
} else {
if (ctx.HasInput("EndsTensorList")) {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("EndsTensorList") > 0) {
if (ctx.InputSize("StepsTensorList") > 0) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
Expand All @@ -810,7 +810,7 @@ KernelSignature SetValueGradOpArgumentMapping(
{GradVarName("Input"), GradVarName("ValueTensor")});
}
} else {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StepsTensorList") > 0) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
Expand Down