Skip to content

Commit

Permalink
Integrate rmsnorm kernel (PaddlePaddle#54998)
Browse files Browse the repository at this point in the history
* add rmsnorm kernel
* add static graph test
* fix round type
* use alignas to avoid msvc compile error
* remove redundant headerfile to avoid rocm compile error
* fix rocm compile not found cub
* Add document
  • Loading branch information
MARD1NO authored and wz1qqx committed Jul 31, 2023
1 parent d725347 commit 7f7189d
Show file tree
Hide file tree
Showing 9 changed files with 1,625 additions and 0 deletions.
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1994,6 +1994,16 @@
data_type : x
backward : reverse_grad

- op : rms_norm
args : (Tensor x, Tensor weight, Tensor bias, float epsilon, int begin_norm_axis)
output : Tensor(out)
infer_meta :
func : RmsNormInferMeta
kernel :
func : rms_norm
data_type : x
optional : bias

- op : rmsprop_
args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false)
output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out), Tensor(master_param_outs)
Expand Down
32 changes: 32 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3137,6 +3137,38 @@ void Unpool3dInferMeta(const MetaTensor& x,
}
}

void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const float epsilon,
const int begin_norm_axis,
MetaTensor* out) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
auto x_dims_size = x_dims_vec.size();

size_t normalized_dims = 1;
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
normalized_dims *= x_dims_vec[i];
}

PADDLE_ENFORCE_EQ(normalized_dims,
weight.dims()[0],
phi::errors::InvalidArgument(
"The normalized size of Input(X) must equal to be"
"the size of Weight, but received"
"normalized size of Input(X) is [%d], received size"
"of Weight is [%d]",
normalized_dims,
weight.dims()[0]));

auto out_dims = phi::make_ddim(x_dims_vec);

out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
}

} // namespace phi

PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,4 +479,11 @@ void Unpool3dInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const float epsilon,
const int begin_norm_axis,
MetaTensor* out);

} // namespace phi
Loading

0 comments on commit 7f7189d

Please sign in to comment.