From 7f7189dddfc071ae31bbffa82f958005d7f5c44f Mon Sep 17 00:00:00 2001 From: ZZK <359521840@qq.com> Date: Tue, 11 Jul 2023 15:39:01 +0800 Subject: [PATCH] Integrate rmsnorm kernel (#54998) * add rmsnorm kernel * add static graph test * fix round type * use alignas to avoid msvc compile error * remove redundant headerfile to avoid rocm compile error * fix rocm compile not found cub * Add document --- paddle/phi/api/yaml/ops.yaml | 10 + paddle/phi/infermeta/binary.cc | 32 + paddle/phi/infermeta/binary.h | 7 + paddle/phi/kernels/gpu/rms_norm_kernel.cu | 1252 +++++++++++++++++ paddle/phi/kernels/rms_norm_kernel.h | 85 ++ .../paddle/incubate/nn/functional/__init__.py | 2 + .../paddle/incubate/nn/functional/rms_norm.py | 59 + test/legacy_test/CMakeLists.txt | 2 + test/legacy_test/test_rms_norm_op.py | 176 +++ 9 files changed, 1625 insertions(+) create mode 100644 paddle/phi/kernels/gpu/rms_norm_kernel.cu create mode 100644 paddle/phi/kernels/rms_norm_kernel.h create mode 100644 python/paddle/incubate/nn/functional/rms_norm.py create mode 100644 test/legacy_test/test_rms_norm_op.py diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 1962836e3ddb2..b87794af529eb 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1994,6 +1994,16 @@ data_type : x backward : reverse_grad +- op : rms_norm + args : (Tensor x, Tensor weight, Tensor bias, float epsilon, int begin_norm_axis) + output : Tensor(out) + infer_meta : + func : RmsNormInferMeta + kernel : + func : rms_norm + data_type : x + optional : bias + - op : rmsprop_ args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false) output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out), Tensor(master_param_outs) diff --git 
a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 0a3c429f099d1..2698563eacc77 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -3137,6 +3137,38 @@ void Unpool3dInferMeta(const MetaTensor& x, } } +void RmsNormInferMeta(const MetaTensor& x, + const MetaTensor& weight, + const MetaTensor& bias, + const float epsilon, + const int begin_norm_axis, + MetaTensor* out) { + std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims()); + auto x_dims_size = x_dims_vec.size(); + + size_t normalized_dims = 1; + for (size_t i = begin_norm_axis; i < x_dims_size; ++i) { + normalized_dims *= x_dims_vec[i]; + } + + PADDLE_ENFORCE_EQ(normalized_dims, + weight.dims()[0], + phi::errors::InvalidArgument( + "The normalized size of Input(X) must equal to be" + "the size of Weight, but received" + "normalized size of Input(X) is [%d], received size" + "of Weight is [%d]", + normalized_dims, + weight.dims()[0])); + + auto out_dims = phi::make_ddim(x_dims_vec); + + out->set_dims(out_dims); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + out->share_lod(x); +} + } // namespace phi PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 0af92a6accdc7..517b259f0149f 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -479,4 +479,11 @@ void Unpool3dInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void RmsNormInferMeta(const MetaTensor& x, + const MetaTensor& weight, + const MetaTensor& bias, + const float epsilon, + const int begin_norm_axis, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/gpu/rms_norm_kernel.cu b/paddle/phi/kernels/gpu/rms_norm_kernel.cu new file mode 100644 index 0000000000000..ccbb1f2f4baa6 --- /dev/null +++ b/paddle/phi/kernels/gpu/rms_norm_kernel.cu @@ -0,0 +1,1252 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// Original OneFlow copyright notice: + +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/layer_norm.cuh +// The following code modified from OneFlow's implementation, and change to use +// single Pass algorithm. Support Int8 quant, dequant Load/Store implementation. 
+ +#include "paddle/phi/kernels/rms_norm_kernel.h" +#include <assert.h> +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#ifndef PADDLE_WITH_HIP +#include <cub/cub.cuh> +#endif + +namespace phi { + +namespace { + +#ifndef PADDLE_WITH_HIP + +constexpr int kWarpSize = 32; + +template <typename T> +struct SumOp { + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return a + b; + } +}; + +template <typename T> +struct MaxOp { + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return max(a, b); + } +}; + +template