Skip to content

Commit

Permalink
【Hackathon 4 No.21】Add i1 / i1e to paddle (#53210)
Browse files Browse the repository at this point in the history
* Add i1 and i1e op

* resolve merge conflicts
  • Loading branch information
LyndonKong authored May 17, 2023
1 parent 8965366 commit a63fb4c
Show file tree
Hide file tree
Showing 23 changed files with 1,128 additions and 0 deletions.
20 changes: 20 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,26 @@
kernel :
func : i0e_grad

# Gradient op for i1 (modified Bessel function of the first kind, order 1).
# Consumes the forward input x, the forward output out, and the incoming
# gradient out_grad; produces x_grad sharing x's meta (UnchangedInferMeta
# over [x]) and dispatches to the i1_grad kernel.
- backward_op : i1_grad
  forward : i1 (Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param : [x]
  kernel :
    func : i1_grad

# Gradient op for i1e, the exponentially scaled variant of i1. Same
# signature and meta propagation as i1_grad; dispatches to i1e_grad.
- backward_op : i1e_grad
  forward : i1e (Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param : [x]
  kernel :
    func : i1e_grad

- backward_op : imag_grad
forward : imag (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
Expand Down
18 changes: 18 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,24 @@
func : i0e
backward : i0e_grad

# Forward op i1: elementwise modified Bessel function of the first kind,
# order 1. Output meta matches the input (UnchangedInferMeta); paired with
# the i1_grad backward op declared in backward.yaml.
- op : i1
  args : (Tensor x)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
  kernel :
    func : i1
  backward : i1_grad

# Forward op i1e: exponentially scaled variant of i1. Same shape/dtype
# propagation; paired with the i1e_grad backward op.
- op : i1e
  args : (Tensor x)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
  kernel :
    func : i1e
  backward : i1e_grad

- op : imag
args : (Tensor x)
output : Tensor (out)
Expand Down
44 changes: 44 additions & 0 deletions paddle/phi/kernels/cpu/i1_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/i1_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_impl.h"

namespace phi {

template <typename T, typename Context>
void I1GradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
const int64_t size = x.numel();
const T* x_data = x.data<T>();
const T* out_data = out.data<T>();
const T* out_grad_data = out_grad.data<T>();
T* x_grad_data = ctx.template Alloc<T>(x_grad);

phi::funcs::ForRange<Context> for_range(ctx, size);
I1GradFunctor<T> functor(x_data, out_data, out_grad_data, x_grad_data, size);
for_range(functor);
}

} // namespace phi

PD_REGISTER_KERNEL(i1_grad, CPU, ALL_LAYOUT, phi::I1GradKernel, float, double) {
}
37 changes: 37 additions & 0 deletions paddle/phi/kernels/cpu/i1_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/i1_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"

namespace phi {

/**
 * @brief CPU forward kernel for i1: applies I1Functor elementwise to x and
 *        stores the result in out.
 * @param ctx CPU device context, used to allocate out.
 * @param x   input tensor.
 * @param out output tensor, same element count as x.
 */
template <typename T, typename Context>
void I1Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  // Allocate the destination before handing raw pointers to the functor.
  T* dst = ctx.template Alloc<T>(out);
  const int64_t numel = x.numel();

  phi::funcs::ForRange<Context> range(ctx, numel);
  range(I1Functor<T>(x.data<T>(), dst, numel));
}

}  // namespace phi

PD_REGISTER_KERNEL(i1, CPU, ALL_LAYOUT, phi::I1Kernel, float, double) {}
44 changes: 44 additions & 0 deletions paddle/phi/kernels/cpu/i1e_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/i1e_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_impl.h"

namespace phi {

template <typename T, typename Context>
void I1eGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
const int64_t size = x.numel();
const T* x_data = x.data<T>();
const T* out_data = out.data<T>();
const T* out_grad_data = out_grad.data<T>();
T* x_grad_data = ctx.template Alloc<T>(x_grad);

phi::funcs::ForRange<Context> for_range(ctx, size);
I1eGradFunctor<T> functor(x_data, out_data, out_grad_data, x_grad_data, size);
for_range(functor);
}

} // namespace phi

PD_REGISTER_KERNEL(
i1e_grad, CPU, ALL_LAYOUT, phi::I1eGradKernel, float, double) {}
37 changes: 37 additions & 0 deletions paddle/phi/kernels/cpu/i1e_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/i1e_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"

namespace phi {

/**
 * @brief CPU forward kernel for i1e: applies I1eFunctor elementwise to x
 *        and stores the result in out.
 * @param ctx CPU device context, used to allocate out.
 * @param x   input tensor.
 * @param out output tensor, same element count as x.
 */
template <typename T, typename Context>
void I1eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  // Allocate the destination before handing raw pointers to the functor.
  T* dst = ctx.template Alloc<T>(out);
  const int64_t numel = x.numel();

  phi::funcs::ForRange<Context> range(ctx, numel);
  range(I1eFunctor<T>(x.data<T>(), dst, numel));
}

}  // namespace phi

PD_REGISTER_KERNEL(i1e, CPU, ALL_LAYOUT, phi::I1eKernel, float, double) {}
37 changes: 37 additions & 0 deletions paddle/phi/kernels/gpu/i1_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/i1_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h"

namespace phi {

/**
 * @brief GPU backward kernel for i1. Launches an elementwise kernel over
 *        (x, out, out_grad) with CudaI1GradFunctor to fill x_grad.
 * @param ctx      GPU device context, used to allocate x_grad and launch.
 * @param x        forward input tensor.
 * @param out      forward output tensor (out = i1(x)).
 * @param out_grad gradient of the loss w.r.t. out.
 * @param x_grad   output: gradient of the loss w.r.t. x.
 */
template <typename T, typename Context>
void I1GradKernel(const Context& ctx,
                  const DenseTensor& x,
                  const DenseTensor& out,
                  const DenseTensor& out_grad,
                  DenseTensor* x_grad) {
  // ElementwiseKernel writes into pre-allocated outputs.
  ctx.template Alloc<T>(x_grad);
  std::vector<const DenseTensor*> inputs{&x, &out, &out_grad};
  std::vector<DenseTensor*> outputs{x_grad};
  phi::funcs::ElementwiseKernel<T>(
      ctx, inputs, &outputs, CudaI1GradFunctor<T>());
}

}  // namespace phi

PD_REGISTER_KERNEL(
    i1_grad, GPU, ALL_LAYOUT, phi::I1GradKernel, float, double) {}
32 changes: 32 additions & 0 deletions paddle/phi/kernels/gpu/i1_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/i1_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"

namespace phi {

/**
 * @brief GPU forward kernel for i1: launches an elementwise kernel applying
 *        CudaI1Functor to x and writing the result into out.
 * @param ctx GPU device context, used to allocate out and launch.
 * @param x   input tensor.
 * @param out output tensor, same element count as x.
 */
template <typename T, typename Context>
void I1Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  // ElementwiseKernel writes into pre-allocated outputs.
  ctx.template Alloc<T>(out);
  std::vector<const DenseTensor*> inputs{&x};
  std::vector<DenseTensor*> outputs{out};
  phi::funcs::ElementwiseKernel<T>(ctx, inputs, &outputs, CudaI1Functor<T>());
}

}  // namespace phi

PD_REGISTER_KERNEL(i1, GPU, ALL_LAYOUT, phi::I1Kernel, float, double) {}
37 changes: 37 additions & 0 deletions paddle/phi/kernels/gpu/i1e_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/i1e_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h"

namespace phi {

/**
 * @brief GPU backward kernel for i1e (exponentially scaled i1). Launches an
 *        elementwise kernel over (x, out, out_grad) with CudaI1eGradFunctor
 *        to fill x_grad.
 * @param ctx      GPU device context, used to allocate x_grad and launch.
 * @param x        forward input tensor.
 * @param out      forward output tensor (out = i1e(x)).
 * @param out_grad gradient of the loss w.r.t. out.
 * @param x_grad   output: gradient of the loss w.r.t. x.
 */
template <typename T, typename Context>
void I1eGradKernel(const Context& ctx,
                   const DenseTensor& x,
                   const DenseTensor& out,
                   const DenseTensor& out_grad,
                   DenseTensor* x_grad) {
  // ElementwiseKernel writes into pre-allocated outputs.
  ctx.template Alloc<T>(x_grad);
  std::vector<const DenseTensor*> inputs{&x, &out, &out_grad};
  std::vector<DenseTensor*> outputs{x_grad};
  phi::funcs::ElementwiseKernel<T>(
      ctx, inputs, &outputs, CudaI1eGradFunctor<T>());
}

}  // namespace phi

PD_REGISTER_KERNEL(
    i1e_grad, GPU, ALL_LAYOUT, phi::I1eGradKernel, float, double) {}
32 changes: 32 additions & 0 deletions paddle/phi/kernels/gpu/i1e_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/i1e_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"

namespace phi {

/**
 * @brief GPU forward kernel for i1e: launches an elementwise kernel applying
 *        CudaI1eFunctor to x and writing the result into out.
 * @param ctx GPU device context, used to allocate out and launch.
 * @param x   input tensor.
 * @param out output tensor, same element count as x.
 */
template <typename T, typename Context>
void I1eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  // ElementwiseKernel writes into pre-allocated outputs.
  ctx.template Alloc<T>(out);
  std::vector<const DenseTensor*> inputs{&x};
  std::vector<DenseTensor*> outputs{out};
  phi::funcs::ElementwiseKernel<T>(ctx, inputs, &outputs, CudaI1eFunctor<T>());
}

}  // namespace phi

PD_REGISTER_KERNEL(i1e, GPU, ALL_LAYOUT, phi::I1eKernel, float, double) {}
38 changes: 38 additions & 0 deletions paddle/phi/kernels/i1_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

/**
 * @brief Calculates the gradient of the modified Bessel function of the
 *        first kind, order 1 (i1), with respect to its input.
 * @param ctx      device context (CPU or GPU), used for allocation/launch.
 * @param x        forward input tensor.
 * @param out      forward output tensor (out = i1(x)).
 * @param out_grad gradient of the loss w.r.t. out.
 * @param x_grad   output: gradient of the loss w.r.t. x; allocated by the
 *                 kernel and filled elementwise.
 */

template <typename T, typename Context>
void I1GradKernel(const Context& ctx,
                  const DenseTensor& x,
                  const DenseTensor& out,
                  const DenseTensor& out_grad,
                  DenseTensor* x_grad);

}  // namespace phi
Loading

0 comments on commit a63fb4c

Please sign in to comment.