From 53614a47e25efcb26367b41cb1987f6b6f10b277 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Mon, 13 Mar 2023 15:20:50 +0000 Subject: [PATCH] support auto generate for nonzero --- paddle/fluid/operators/where_index_op.cc | 58 ------------------------ paddle/phi/api/yaml/legacy_ops.yaml | 8 ---- paddle/phi/api/yaml/ops.yaml | 9 ++++ paddle/phi/ops/compat/where_index_sig.cc | 27 ----------- 4 files changed, 9 insertions(+), 93 deletions(-) delete mode 100644 paddle/fluid/operators/where_index_op.cc delete mode 100644 paddle/phi/ops/compat/where_index_sig.cc diff --git a/paddle/fluid/operators/where_index_op.cc b/paddle/fluid/operators/where_index_op.cc deleted file mode 100644 index 2b19b62595eec..0000000000000 --- a/paddle/fluid/operators/where_index_op.cc +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class WhereIndexOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Condition"); - return phi::KernelKey(data_type, ctx.GetPlace()); - } -}; - -class WhereIndexOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Condition", "A bool tensor whose rank is at least 1"); - AddOutput("Out", "An int64 tensor of rank 2"); - AddComment(R"DOC( - Return a int64 tensor with rank 2, specifying the coordinate of true element in `Condition`. -)DOC"); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(where_index, - WhereIndexInferShapeFunctor, - PD_INFER_META(phi::NonZeroInferMeta)); -REGISTER_OPERATOR( - where_index, - ops::WhereIndexOp, - ops::WhereIndexOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker, - WhereIndexInferShapeFunctor); diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index c0e94957fb9e4..b452ed939b110 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1207,14 +1207,6 @@ func : nms data_type : x -- op : nonzero - args : (Tensor condition) - output : Tensor(out) - infer_meta : - func : NonZeroInferMeta - kernel : - func : nonzero - - op : norm args : (Tensor x, int axis, float epsilon, bool is_test) output : Tensor(out), Tensor(norm) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 5aabfb1b13b74..137d38ddb5999 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1001,6 +1001,15 @@ optional : weight backward : nll_loss_grad +- op : nonzero + args : (Tensor condition) + output : Tensor(out) + infer_meta : + func : NonZeroInferMeta + kernel : + func : nonzero + data_type: condition + - op : npu_identity args : (Tensor x, int format = -1) output : Tensor diff --git a/paddle/phi/ops/compat/where_index_sig.cc b/paddle/phi/ops/compat/where_index_sig.cc deleted file mode 100644 index cfe2a8110cc84..0000000000000 --- a/paddle/phi/ops/compat/where_index_sig.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature WhereIndexOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("nonzero", {"Condition"}, {}, {"Out"}); -} - -} // namespace phi - -PD_REGISTER_BASE_KERNEL_NAME(where_index, nonzero); - -PD_REGISTER_ARG_MAPPING_FN(where_index, phi::WhereIndexOpArgumentMapping);