From e4938e0a9cee15ffe2f8d205e0228c1842a5735c Mon Sep 17 00:00:00 2001
From: YiSheng5
Date: Wed, 24 Jul 2024 10:17:28 +0800
Subject: [PATCH] add the backward implementation for rms norm (#4517) (#4527)

---
 csrc/gpu/aten/operators/RMSNorm.cpp | 265 +++++++++++++++++-
 .../xpu/intrinsic/__init__.py       |   5 +
 tests/gpu/examples/test_rms_norm.py |  45 +++
 3 files changed, 313 insertions(+), 2 deletions(-)

diff --git a/csrc/gpu/aten/operators/RMSNorm.cpp b/csrc/gpu/aten/operators/RMSNorm.cpp
index a9ededb16..4d2fbecf9 100644
--- a/csrc/gpu/aten/operators/RMSNorm.cpp
+++ b/csrc/gpu/aten/operators/RMSNorm.cpp
@@ -2,17 +2,37 @@
 #include
 #include
+#include
 #include
+#include
+#include
+#include
 #include "Norm.h"
+#include "comm/ATDispatch.h"
 #include "comm/RegistrationDeclarations.h"
 #include "utils/CustomOperatorRegistration.h"

 using namespace xpu::dpcpp;
+using namespace torch::autograd;
 using namespace at::AtenIpexTypeXPU::normalization;

 namespace at {
 namespace AtenIpexTypeXPU {

+std::tuple<Tensor, Tensor> rms_norm_fw(
+    const Tensor& input,
+    at::IntArrayRef normalized_shape,
+    const Tensor& weight,
+    double epsilon);
+
+std::tuple<Tensor, Tensor> rms_norm_bw(
+    const Tensor& grad_output,
+    const Tensor& input,
+    at::IntArrayRef normalized_shape,
+    const Tensor& rstd,
+    const Tensor& weight,
+    std::array<bool, 2> grad_input_mask);
+
 template <typename scalar_t, typename mean_t, typename weight_t>
 class RMSNormForward : public NormForward<scalar_t, mean_t, weight_t> {
  public:
@@ -337,12 +357,13 @@ void RMSNormKernelImpl(
       X.scalar_type(),
       "RMSNormKernelImpl",
       [&]() {
-        rstd = at::empty({M}, X.options().dtype(kFloat));
         if (gamma.scalar_type() == kFloat) {
+          rstd = at::empty({M}, X.options().dtype(kFloat));
           RMSNormKernelImplInternal(
               X, gamma, M, N, static_cast<acc_type<scalar_t>>(eps), Y, rstd);
         } else {
-          RMSNormKernelImplInternal(
+          rstd = at::empty({M}, X.options());
+          RMSNormKernelImplInternal(
               X, gamma, M, N, static_cast<acc_type<scalar_t>>(eps), Y, rstd);
         }
       });
@@ -374,11 +395,251 @@ std::tuple<Tensor, Tensor> rms_norm_fw(
   return std::make_tuple(output.reshape(input.sizes()), rstd);
 }

+template <typename scalar_t, typename mean_t, typename weight_t>
+void RmsNormBackwardKernelImplInternal(
+    const Tensor& dY,
+    const Tensor& X,
+    const Tensor& rstd,
+    const Tensor& gamma,
+    int64_t M,
+    int64_t N,
+    Tensor& dX,
+    Tensor& dgamma,
+    const Tensor& output,
+    std::array<bool, 2> grad_input_mask) {
+  TORCH_CHECK(dY.numel() == M * N);
+  TORCH_CHECK(rstd.numel() == M);
+
+  using accscalar_t = acc_type<scalar_t>;
+  mean_t* var_data = rstd.data_ptr<mean_t>();
+  weight_t* gamma_data = gamma.defined() ? gamma.data_ptr<weight_t>() : nullptr;
+
+  if (grad_input_mask[0]) {
+    // backward data
+    scalar_t* X_data = X.data_ptr<scalar_t>();
+    scalar_t* dY_data = dY.data_ptr<scalar_t>();
+    scalar_t* dX_data = dX.data_ptr<scalar_t>();
+
+    auto config = NormConfig(M, N, 1, sizeof(scalar_t));
+    bool can_use_32bit_index = canUse32BitIndexMath(X) &&
+        canUse32BitIndexMath(dY) && canUse32BitIndexMath(dX);
+
+    // TODO: force it to use fused_norm_kernel
+    config.workgroup_num_foreach = 1;
+    config.WGPlane = config.Plane;
+
+    if (config.workgroup_num_foreach == 1) {
+      RMSNormBackward<scalar_t, mean_t, weight_t> rms_norm_backward(
+          X_data, dY_data, dX_data, var_data, gamma_data, M, N);
+      launch_vectorized_fused_norm_kernel<
+          scalar_t,
+          mean_t,
+          weight_t,
+          RMSNormBackward<scalar_t, mean_t, weight_t>,
+          true>(rms_norm_backward, config, can_use_32bit_index);
+    } else {
+      const auto kAccType =
+          (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16)
+          ? kFloat
+          : X.scalar_type();
+      Tensor a = at::empty({M}, X.options().dtype(kAccType));
+      accscalar_t* a_data = a.data_ptr<accscalar_t>();
+
+      RMSNormBackward<scalar_t, mean_t, weight_t> rms_norm_backward(
+          X_data, dY_data, dX_data, var_data, gamma_data, a_data, M, N);
+      Tensor semaphores, scratchpad;
+      config.template init_global_reduce(
+          X, semaphores, scratchpad);
+      RowwiseMomentsDPCPPKernelImpl<
+          scalar_t,
+          mean_t,
+          weight_t,
+          RMSNormBackward<scalar_t, mean_t, weight_t>,
+          true>(rms_norm_backward, config, can_use_32bit_index);
+      NormUpdateKernelImpl(
+          rms_norm_backward, config, can_use_32bit_index);
+    }
+  }
+
+  if (grad_input_mask[1]) {
+    // backward weight
+    Tensor sum_tmp = at::mul(output, dY);
+    at::sum_out(dgamma, sum_tmp, at::IntArrayRef{0, 1});
+  }
+}
+
+void RmsNormBackwardKernelImpl(
+    const Tensor& dY,
+    const Tensor& X,
+    const Tensor& rstd,
+    const Tensor& gamma,
+    int64_t M,
+    int64_t N,
+    Tensor& dX,
+    Tensor& dgamma,
+    const Tensor& output,
+    std::array<bool, 2> grad_input_mask) {
+  IPEX_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      X.scalar_type(),
+      "RmsNormBackwardKernelImpl",
+      [&]() {
+        using accscalar_t = acc_type<scalar_t>;
+        if (gamma.scalar_type() == kFloat) {
+          RmsNormBackwardKernelImplInternal(
+              dY, X, rstd, gamma, M, N, dX, dgamma, output, grad_input_mask);
+        } else {
+          RmsNormBackwardKernelImplInternal(
+              dY, X, rstd, gamma, M, N, dX, dgamma, output, grad_input_mask);
+        }
+      });
+}
+
+std::tuple<Tensor, Tensor> rms_norm_bw(
+    const Tensor& grad_output,
+    const Tensor& input,
+    at::IntArrayRef normalized_shape,
+    const Tensor& rstd,
+    const Tensor& weight,
+    const Tensor& output,
+    std::array<bool, 2> grad_input_mask) {
+  RECORD_FUNCTION("ipex::rms_norm_bw", std::vector<c10::IValue>({grad_output}));
+  auto M_N =
+      _check_layer_norm_inputs(input, normalized_shape, weight, Tensor());
+  auto M = M_N.first;
+  auto N = M_N.second;
+
+  Tensor grad_input;
+  Tensor grad_weight;
+
+  if (grad_input_mask[0]) {
+    grad_input = at::native::empty_like(
+        input,
+        c10::nullopt /* dtype */,
+        c10::nullopt /* layout */,
+        c10::nullopt /* device */,
+        c10::nullopt /* pin_memory */,
+        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+
+  if (grad_input_mask[1]) {
+    grad_weight = M > 0 ? at::native::empty_like(
+                              weight,
+                              c10::nullopt /* dtype */,
+                              c10::nullopt /* layout */,
+                              c10::nullopt /* device */,
+                              c10::nullopt /* pin_memory */,
+                              LEGACY_CONTIGUOUS_MEMORY_FORMAT)
+                        : at::native::zeros_like(
+                              weight,
+                              c10::nullopt /* dtype */,
+                              c10::nullopt /* layout */,
+                              c10::nullopt /* device */,
+                              c10::nullopt /* pin_memory */,
+                              LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+
+  if (input.numel() != 0 && grad_output.numel() != 0) {
+    Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input;
+    Tensor grad_output_ =
+        (grad_output.dim() == 1) ? grad_output.reshape({M, N}) : grad_output;
+    Tensor weight_ =
+        (weight.defined() && weight.dim() == 1) ? weight.reshape({N}) : weight;
+    Tensor output_ = (output.dim() == 1) ? output.reshape({M, N}) : output;
+
+    input_ = input_.contiguous();
+    grad_output_ = grad_output_.contiguous();
+    output_ = output_.contiguous();
+    weight_ = weight_.defined() ? weight_.contiguous() : weight_;
+
+    RmsNormBackwardKernelImpl(
+        grad_output_,
+        input_,
+        rstd,
+        weight_,
+        M,
+        N,
+        grad_input,
+        grad_weight,
+        output_,
+        grad_input_mask);
+  }
+  return std::make_tuple(
+      grad_input_mask[0] ? grad_input.reshape(input.sizes()) : grad_input,
+      grad_input_mask[1] ? grad_weight.reshape(weight.sizes()) : grad_weight);
+}
+
+class IPEXRmsNormOp : public Function<IPEXRmsNormOp> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      const Tensor& input,
+      at::IntArrayRef normalized_shape,
+      const Tensor& weight,
+      double epsilon) {
+#ifdef BUILD_SIMPLE_TRACE
+    SimpleTrace trace(
+        "IPEXRmsNormOp forward -> at::AtenIpexTypeXPU::IPEXRmsNormOp::forward");
+#endif
+    ctx->saved_data["input_requires_grad"] = input.requires_grad();
+    ctx->saved_data["weight_requires_grad"] = weight.requires_grad();
+    ctx->saved_data["normalized_shape"] = normalized_shape;
+    auto outputs = rms_norm_fw(input, normalized_shape, weight, epsilon);
+
+    ctx->save_for_backward(
+        {input, weight, std::get<0>(outputs), std::get<1>(outputs)});
+    variable_list result = {std::get<0>(outputs), std::get<1>(outputs)};
+    return result;
+  }
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_outputs) {
+#ifdef BUILD_SIMPLE_TRACE
+    SimpleTrace trace(
+        "IPEXRmsNormOp backward -> at::AtenIpexTypeXPU::IPEXRmsNormOp::backward");
+#endif
+    auto weight_requires_grad =
+        ctx->saved_data["weight_requires_grad"].toBool();
+    auto input_requires_grad = ctx->saved_data["input_requires_grad"].toBool();
+    auto saved = ctx->get_saved_variables();
+    Tensor input = saved[0];
+    Tensor weight = saved[1];
+    Tensor output = saved[2];
+    Tensor rstd = saved[3];
+    auto normalized_shape = weight.sizes();
+
+    auto grad_inputs = rms_norm_bw(
+        grad_outputs[0],
+        input,
+        normalized_shape,
+        rstd,
+        weight,
+        output,
+        {input_requires_grad, weight_requires_grad});
+    return {
+        std::get<0>(grad_inputs), Tensor(), std::get<1>(grad_inputs), Tensor()};
+  }
+};
+
+Tensor rms_norm_impl(
+    const Tensor& input,
+    at::IntArrayRef normalized_shape,
+    const Tensor& weight,
+    double epsilon) {
+  auto output = IPEXRmsNormOp::apply(input, normalized_shape, weight, epsilon);
+  return output[0];
+}
 } // namespace AtenIpexTypeXPU
 } // namespace at

 namespace {
 IPEX_LIBRARY_FRAGMENT() {
+  IPEX_OP_REGISTER_DISPATCH(
+      "rms_norm_impl",
+      at::AtenIpexTypeXPU::rms_norm_impl,
+      c10::DispatchKey::AutogradXPU);
   IPEX_OP_REGISTER("rms_norm.xpu", at::AtenIpexTypeXPU::rms_norm_fw);
 }
 } // namespace
diff --git a/intel_extension_for_pytorch/xpu/intrinsic/__init__.py b/intel_extension_for_pytorch/xpu/intrinsic/__init__.py
index 5a1418641..cf354c77e 100644
--- a/intel_extension_for_pytorch/xpu/intrinsic/__init__.py
+++ b/intel_extension_for_pytorch/xpu/intrinsic/__init__.py
@@ -25,6 +25,7 @@
     "copy_blocks",
     "swap_blocks",
     "IpexPaged_attention",
+    "IpexRmsNorm",
 ]


@@ -164,6 +165,10 @@ def IpexSDP_dropout(
     )


+def IpexRmsNorm(input, normalized_shape, weight, epsilon) -> Tensor:
+    return torch.ops.torch_ipex.rms_norm_impl(input, normalized_shape, weight, epsilon)
+
+
 def varlen_fwd(
     query,  # [total_q, num_head, head_size]
     key,  # [total_k, num_head_k, head_size]
diff --git a/tests/gpu/examples/test_rms_norm.py b/tests/gpu/examples/test_rms_norm.py
index 9d17728e5..bad521625 100644
--- a/tests/gpu/examples/test_rms_norm.py
+++ b/tests/gpu/examples/test_rms_norm.py
@@ -40,11 +40,56 @@ def test_rms_norm_fw_xpu(dtype):
                 w = model.weight.xpu()
                 output = torch.ops.torch_ipex.rms_norm(input_case, [hsz], w, 1e-5)
                 output1 = ipex.llm.modules.RMSNorm.apply(input_case, w, 1e-5)
+                output2 = torch.xpu.IpexRmsNorm(input_case, [hsz], w, 1e-5)
                 # diff = (output.cpu() - output_ref).abs().max().item()
                 # print('diff', diff)
                 # assert diff < 1e-2
                 self.assertEqual(output[0].cpu(), output_ref, atol=1e-2, rtol=1e-2)
                 self.assertEqual(output1.cpu(), output_ref, atol=1e-2, rtol=1e-2)
+                self.assertEqual(output2.cpu(), output_ref, atol=1e-2, rtol=1e-2)

         test_rms_norm_fw_xpu(torch.float)
         test_rms_norm_fw_xpu(torch.bfloat16)
+
+    def test_rms_norm_bw(self):
+        def test_rms_norm_fwd_bwd(dtype):
+            print("test_rms_norm_fw_bw", dtype)
+            torch.manual_seed(13)
+            modelb = RMSNormRef(64)
+            model0 = RMSNormRef(768)
+            model1 = RMSNormRef(2048)
+            model2 = RMSNormRef(4096)
+            model3 = RMSNormRef(16384)
+            model4 = RMSNormRef(16384 * 4 + 123)
+            hszs = [64, 768, 2048, 4096, 16384, 16384 * 4 + 123]
+            ls = [modelb, model0, model1, model2, model3, model4]
+            for i, model in enumerate(ls):
+                model = model.to(dtype)
+                hsz = hszs[i]
+                input_case = torch.rand(4, 1024, hsz).to(dtype)
+                input_case.requires_grad_(True)
+                grad = torch.rand(4, 1024, hsz).to(dtype)
+                output_ref = model(input_case)
+                output_ref.backward(grad)
+                grad_wei = model.weight.grad.clone()
+                input_grad_cpu = input_case.grad.clone()
+                w = model.weight.clone()
+
+                input_case_xpu = input_case.clone().xpu()
+                input_case_xpu.retain_grad()
+                input_case_xpu.requires_grad_(True)
+                grad_xpu = grad.xpu()
+                w = w.xpu()
+                w.retain_grad()
+                w.requires_grad_(True)
+                output1 = torch.xpu.IpexRmsNorm(input_case_xpu, [hsz], w, 1e-5)
+                output1.backward(grad_xpu)
+                grad_wei_xpu = w.grad
+
+                self.assertEqual(grad_wei_xpu.cpu(), grad_wei, atol=10e-2, rtol=10e-2)
+                self.assertEqual(
+                    input_case_xpu.grad.cpu(), input_grad_cpu, atol=10e-2, rtol=10e-2
+                )
+
+        test_rms_norm_fwd_bwd(torch.bfloat16)
+        test_rms_norm_fwd_bwd(torch.float)
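
Usage sketch: a minimal forward/backward call through the new torch.xpu.IpexRmsNorm binding, which dispatches to torch.ops.torch_ipex.rms_norm_impl and therefore exercises the new rms_norm_bw kernel on the backward pass. It assumes an XPU device is available and that intel_extension_for_pytorch is built with this patch; the tensor shapes and epsilon are illustrative only, not taken from the patch.

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the torch.xpu.* intrinsics)

hidden = 4096  # illustrative hidden size
x = torch.rand(4, 1024, hidden, device="xpu", dtype=torch.bfloat16, requires_grad=True)
w = torch.ones(hidden, device="xpu", dtype=torch.bfloat16, requires_grad=True)

# Forward: rms_norm_impl -> IPEXRmsNormOp::forward -> rms_norm_fw
y = torch.xpu.IpexRmsNorm(x, [hidden], w, 1e-5)

# Backward: IPEXRmsNormOp::backward -> rms_norm_bw, producing grads for input and weight
y.backward(torch.ones_like(y))
print(x.grad.shape, w.grad.shape)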