Skip to content

Commit

Permalink
[phi] move shape op (#40248)
Browse files Browse the repository at this point in the history
* add selected row op and fix bug in ctest

* modify the date

* fix bug in npu and xpu

* modfiy the include file
  • Loading branch information
Liu-xiandong authored Mar 10, 2022
1 parent df60166 commit 575dea8
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 90 deletions.
21 changes: 17 additions & 4 deletions paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,32 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/shape_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = phi::SelectedRows;

template <typename T>
class ShapeMKLDNNKernel : public ShapeKernel<T> {
class ShapeMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ShapeKernel<T>::Compute(ctx);
auto* in_var = ctx.InputVar("Input");
framework::DDim in_dims;
if (in_var->IsType<phi::SelectedRows>()) {
in_dims = in_var->Get<phi::SelectedRows>().value().dims();
} else {
in_dims = in_var->Get<LoDTensor>().dims();
}
auto* out_t = ctx.Output<Tensor>("Out");
out_t->Resize({in_dims.size()});
auto out_data = out_t->mutable_data<int32_t>(platform::CPUPlace());
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = in_dims[i];
}

auto* out = ctx.Output<Tensor>("Out");
out->set_layout(framework::DataLayout::kMKLDNN);
Expand Down
8 changes: 0 additions & 8 deletions paddle/fluid/operators/shape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/shape_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -95,9 +93,3 @@ REGISTER_OPERATOR(
shape, ops::ShapeOp, ops::ShapeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<bool>, ops::ShapeKernel<int>,
ops::ShapeKernel<int8_t>, ops::ShapeKernel<uint8_t>,
ops::ShapeKernel<int64_t>, ops::ShapeKernel<float>,
ops::ShapeKernel<double>,
ops::ShapeKernel<plat::complex<float>>,
ops::ShapeKernel<plat::complex<double>>);
27 changes: 0 additions & 27 deletions paddle/fluid/operators/shape_op.cu

This file was deleted.

46 changes: 0 additions & 46 deletions paddle/fluid/operators/shape_op.h

This file was deleted.

2 changes: 1 addition & 1 deletion paddle/fluid/operators/shape_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/shape_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"

namespace paddle {
Expand Down
37 changes: 33 additions & 4 deletions paddle/fluid/operators/shape_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,41 @@
* limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"

#include "paddle/fluid/operators/shape_op.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = phi::SelectedRows;

template <typename T>
class ShapeXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_var = ctx.InputVar("Input");
framework::DDim in_dims;
if (in_var->IsType<phi::SelectedRows>()) {
in_dims = in_var->Get<phi::SelectedRows>().value().dims();
} else {
in_dims = in_var->Get<LoDTensor>().dims();
}
auto* out_t = ctx.Output<Tensor>("Out");
out_t->Resize({in_dims.size()});
auto out_data = out_t->mutable_data<int32_t>(platform::CPUPlace());
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = in_dims[i];
}
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(shape, ops::ShapeKernel<bool>, ops::ShapeKernel<int>,
ops::ShapeKernel<int64_t>, ops::ShapeKernel<float>,
ops::ShapeKernel<double>);
REGISTER_OP_XPU_KERNEL(shape, ops::ShapeXPUKernel<bool>,
ops::ShapeXPUKernel<int>, ops::ShapeXPUKernel<int64_t>,
ops::ShapeXPUKernel<float>, ops::ShapeXPUKernel<double>);

#endif
33 changes: 33 additions & 0 deletions paddle/phi/kernels/cpu/shape_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/shape_kernel.h"
#include "paddle/phi/kernels/impl/shape_kernel_impl.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(shape,
CPU,
ALL_LAYOUT,
phi::ShapeKernel,
bool,
int,
int8_t,
uint8_t,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
35 changes: 35 additions & 0 deletions paddle/phi/kernels/gpu/shape_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/shape_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/shape_kernel_impl.h"

PD_REGISTER_KERNEL(shape,
GPU,
ALL_LAYOUT,
phi::ShapeKernel,
bool,
int,
int8_t,
uint8_t,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16) {}
36 changes: 36 additions & 0 deletions paddle/phi/kernels/impl/shape_kernel_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void ShapeKernel(const Context& ctx,
const DenseTensor& input,
DenseTensor* out) {
auto in_var = &input;
phi::DDim in_dims;
in_dims = in_var->dims();
auto out_t = out;
out_t->Resize({in_dims.size()});
auto out_data = ctx.template HostAlloc<int32_t>(out_t);
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = in_dims[i];
}
}

} // namespace phi
70 changes: 70 additions & 0 deletions paddle/phi/kernels/selected_rows/shape_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/selected_rows/shape_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace sr {

template <typename T, typename Context>
void ShapeKernel(const Context& ctx,
const SelectedRows& input,
DenseTensor* out) {
auto in_var = input;
phi::DDim in_dims;
in_dims = in_var.value().dims();
auto out_t = out;
out_t->Resize({in_dims.size()});
auto out_data = ctx.template HostAlloc<int32_t>(out_t);
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = in_dims[i];
}
}

} // namespace sr
} // namespace phi

PD_REGISTER_KERNEL(shape_sr,
CPU,
ALL_LAYOUT,
phi::sr::ShapeKernel,
bool,
int,
int8_t,
uint8_t,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(shape_sr,
GPU,
ALL_LAYOUT,
phi::sr::ShapeKernel,
bool,
int,
int8_t,
uint8_t,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
28 changes: 28 additions & 0 deletions paddle/phi/kernels/selected_rows/shape_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/selected_rows.h"

namespace phi {
namespace sr {

template <typename T, typename Context>
void ShapeKernel(const Context& ctx,
const SelectedRows& input,
DenseTensor* out);

} // namespace sr
} // namespace phi
Loading

0 comments on commit 575dea8

Please sign in to comment.