add the backward implementation for rms norm (#4517) (#4527)
ys950902 authored Jul 24, 2024
1 parent 020075f commit e4938e0
Showing 3 changed files with 313 additions and 2 deletions.
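
For orientation, RMSNorm normalizes each row x of length N as x̂ = x · rstd with rstd = 1/√(mean(x²) + ε) and y = γ ⊙ x̂. The backward kernels added below therefore need the standard per-row gradients, stated here only for reference (textbook formulas; the kernels may organize the reductions differently):

\[
\mathrm{rstd} = \frac{1}{\sqrt{\tfrac{1}{N}\sum_{j=1}^{N} x_j^2 + \varepsilon}}, \qquad
\hat{x}_i = x_i \cdot \mathrm{rstd}, \qquad y_i = \gamma_i\,\hat{x}_i
\]
\[
\frac{\partial L}{\partial \gamma_i} = \sum_{\text{rows}} \frac{\partial L}{\partial y_i}\,\hat{x}_i, \qquad
\frac{\partial L}{\partial x_i} = \mathrm{rstd}\,\Bigl(\gamma_i\,\frac{\partial L}{\partial y_i} \;-\; \hat{x}_i \cdot \frac{1}{N}\sum_{j=1}^{N} \gamma_j\,\frac{\partial L}{\partial y_j}\,\hat{x}_j\Bigr)
\]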
265 changes: 263 additions & 2 deletions csrc/gpu/aten/operators/RMSNorm.cpp
@@ -2,17 +2,37 @@
#include <ATen/Config.h>
#include <ATen/NativeFunctions.h>

#include <ATen/record_function.h>
#include <oneDNN/oneDNN.h>
#include <torch/autograd.h>
#include <torch/custom_class.h>
#include <utils/SimpleTrace.h>
#include "Norm.h"
#include "comm/ATDispatch.h"
#include "comm/RegistrationDeclarations.h"
#include "utils/CustomOperatorRegistration.h"

using namespace xpu::dpcpp;
using namespace torch::autograd;
using namespace at::AtenIpexTypeXPU::normalization;

namespace at {
namespace AtenIpexTypeXPU {

std::tuple<Tensor, Tensor> rms_norm_fw(
const Tensor& input,
at::IntArrayRef normalized_shape,
const Tensor& weight,
double epsilon);

std::tuple<Tensor, Tensor> rms_norm_bw(
const Tensor& grad_output,
const Tensor& input,
at::IntArrayRef normalized_shape,
const Tensor& rstd,
const Tensor& weight,
std::array<bool, 2> grad_input_mask);

template <typename scalar_t, typename mean_t, typename weight_t>
class RMSNormForward : public NormForward<scalar_t, mean_t, weight_t, true> {
public:
@@ -337,12 +357,13 @@ void RMSNormKernelImpl(
      X.scalar_type(),
      "RMSNormKernelImpl",
      [&]() {
-       rstd = at::empty({M}, X.options().dtype(kFloat));
        if (gamma.scalar_type() == kFloat) {
+         rstd = at::empty({M}, X.options().dtype(kFloat));
          RMSNormKernelImplInternal<scalar_t, float, float>(
              X, gamma, M, N, static_cast<acc_type<scalar_t>>(eps), Y, rstd);
        } else {
-         RMSNormKernelImplInternal<scalar_t, float, scalar_t>(
+         rstd = at::empty({M}, X.options());
+         RMSNormKernelImplInternal<scalar_t, scalar_t, scalar_t>(
              X, gamma, M, N, static_cast<acc_type<scalar_t>>(eps), Y, rstd);
        }
      });
@@ -374,11 +395,251 @@ std::tuple<Tensor, Tensor> rms_norm_fw(
return std::make_tuple(output.reshape(input.sizes()), rstd);
}

template <typename scalar_t, typename mean_t, typename weight_t>
void RmsNormBackwardKernelImplInternal(
const Tensor& dY,
const Tensor& X,
const Tensor& rstd,
const Tensor& gamma,
int64_t M,
int64_t N,
Tensor& dX,
Tensor& dgamma,
const Tensor& output,
std::array<bool, 2> grad_input_mask) {
TORCH_CHECK(dY.numel() == M * N);
TORCH_CHECK(rstd.numel() == M);

using accscalar_t = acc_type<scalar_t>;
mean_t* var_data = rstd.data_ptr<mean_t>();
weight_t* gamma_data = gamma.defined() ? gamma.data_ptr<weight_t>() : nullptr;

if (grad_input_mask[0]) {
// backward data
scalar_t* X_data = X.data_ptr<scalar_t>();
scalar_t* dY_data = dY.data_ptr<scalar_t>();
scalar_t* dX_data = dX.data_ptr<scalar_t>();

auto config = NormConfig(M, N, 1, sizeof(scalar_t));
bool can_use_32bit_index = canUse32BitIndexMath(X) &&
canUse32BitIndexMath(dY) && canUse32BitIndexMath(dX);

// TODO: force it to use fused_norm_kernel
config.workgroup_num_foreach = 1;
config.WGPlane = config.Plane;

if (config.workgroup_num_foreach == 1) {
RMSNormBackward<scalar_t, mean_t, weight_t> rms_norm_backward(
X_data, dY_data, dX_data, var_data, gamma_data, M, N);
launch_vectorized_fused_norm_kernel<
scalar_t,
mean_t,
weight_t,
RMSNormBackward,
true>(rms_norm_backward, config, can_use_32bit_index);
} else {
const auto kAccType =
(X.scalar_type() == kHalf || X.scalar_type() == kBFloat16)
? kFloat
: X.scalar_type();
Tensor a = at::empty({M}, X.options().dtype(kAccType));
accscalar_t* a_data = a.data_ptr<accscalar_t>();

RMSNormBackward<scalar_t, mean_t, weight_t> rms_norm_backward(
X_data, dY_data, dX_data, var_data, gamma_data, a_data, M, N);
Tensor semaphores, scratchpad;
config.template init_global_reduce<accscalar_t>(
X, semaphores, scratchpad);
RowwiseMomentsDPCPPKernelImpl<
scalar_t,
mean_t,
weight_t,
RMSNormBackward,
true>(rms_norm_backward, config, can_use_32bit_index);
NormUpdateKernelImpl<scalar_t, mean_t, weight_t, RMSNormBackward, true>(
rms_norm_backward, config, can_use_32bit_index);
}
}

if (grad_input_mask[1]) {
// backward weight
Tensor sum_tmp = at::mul(output, dY);
at::sum_out(dgamma, sum_tmp, at::IntArrayRef{0, 1});
}
}

void RmsNormBackwardKernelImpl(
const Tensor& dY,
const Tensor& X,
const Tensor& rstd,
const Tensor& gamma,
int64_t M,
int64_t N,
Tensor& dX,
Tensor& dgamma,
const Tensor& output,
std::array<bool, 2> grad_input_mask) {
IPEX_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
X.scalar_type(),
"RmsNormBackwardKernelImpl",
[&]() {
using accscalar_t = acc_type<scalar_t>;
if (gamma.scalar_type() == kFloat) {
RmsNormBackwardKernelImplInternal<scalar_t, float, float>(
dY, X, rstd, gamma, M, N, dX, dgamma, output, grad_input_mask);
} else {
RmsNormBackwardKernelImplInternal<scalar_t, scalar_t, scalar_t>(
dY, X, rstd, gamma, M, N, dX, dgamma, output, grad_input_mask);
}
});
}

std::tuple<Tensor, Tensor> rms_norm_bw(
const Tensor& grad_output,
const Tensor& input,
at::IntArrayRef normalized_shape,
const Tensor& rstd,
const Tensor& weight,
const Tensor& output,
std::array<bool, 2> grad_input_mask) {
RECORD_FUNCTION("ipex::rms_norm_bw", std::vector<c10::IValue>({grad_output}));
auto M_N =
_check_layer_norm_inputs(input, normalized_shape, weight, Tensor());
auto M = M_N.first;
auto N = M_N.second;

Tensor grad_input;
Tensor grad_weight;

if (grad_input_mask[0]) {
grad_input = at::native::empty_like(
input,
c10::nullopt /* dtype */,
c10::nullopt /* layout */,
c10::nullopt /* device */,
c10::nullopt /* pin_memory */,
LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}

if (grad_input_mask[1]) {
grad_weight = M > 0 ? at::native::empty_like(
weight,
c10::nullopt /* dtype */,
c10::nullopt /* layout */,
c10::nullopt /* device */,
c10::nullopt /* pin_memory */,
LEGACY_CONTIGUOUS_MEMORY_FORMAT)
: at::native::zeros_like(
weight,
c10::nullopt /* dtype */,
c10::nullopt /* layout */,
c10::nullopt /* device */,
c10::nullopt /* pin_memory */,
LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}

if (input.numel() != 0 && grad_output.numel() != 0) {
Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input;
Tensor grad_output_ =
(grad_output.dim() == 1) ? grad_output.reshape({M, N}) : grad_output;
Tensor weight_ =
(weight.defined() && weight.dim() == 1) ? weight.reshape({N}) : weight;
Tensor output_ = (output.dim() == 1) ? output.reshape({M, N}) : output;

input_ = input_.contiguous();
grad_output_ = grad_output_.contiguous();
output_ = output_.contiguous();
weight_ = weight_.defined() ? weight_.contiguous() : weight_;

RmsNormBackwardKernelImpl(
grad_output_,
input_,
rstd,
weight_,
M,
N,
grad_input,
grad_weight,
output_,
grad_input_mask);
}
return std::make_tuple(
grad_input_mask[0] ? grad_input.reshape(input.sizes()) : grad_input,
grad_input_mask[1] ? grad_weight.reshape(weight.sizes()) : grad_weight);
}

class IPEXRmsNormOp : public Function<IPEXRmsNormOp> {
public:
static variable_list forward(
AutogradContext* ctx,
const Tensor& input,
at::IntArrayRef normalized_shape,
const Tensor& weight,
double epsilon) {
#ifdef BUILD_SIMPLE_TRACE
SimpleTrace trace(
"IPEXRmsNormOp forward -> at::AtenIpexTypeXPU::IPEXRmsNormOp::forward");
#endif
ctx->saved_data["input_requires_grad"] = input.requires_grad();
ctx->saved_data["weight_requires_grad"] = weight.requires_grad();
ctx->saved_data["normalized_shape"] = normalized_shape;
auto outputs = rms_norm_fw(input, normalized_shape, weight, epsilon);

ctx->save_for_backward(
{input, weight, std::get<0>(outputs), std::get<1>(outputs)});
variable_list result = {std::get<0>(outputs), std::get<1>(outputs)};
return result;
}

static variable_list backward(
AutogradContext* ctx,
variable_list grad_outputs) {
#ifdef BUILD_SIMPLE_TRACE
SimpleTrace trace(
"IPEXRmsNormOp backward -> at::AtenIpexTypeXPU::IPEXRmsNormOp::backward");
#endif
auto weight_requires_grad =
ctx->saved_data["weight_requires_grad"].toBool();
auto input_requires_grad = ctx->saved_data["input_requires_grad"].toBool();
auto saved = ctx->get_saved_variables();
Tensor input = saved[0];
Tensor weight = saved[1];
Tensor output = saved[2];
Tensor rstd = saved[3];
auto normalized_shape = weight.sizes();

auto grad_inputs = rms_norm_bw(
grad_outputs[0],
input,
normalized_shape,
rstd,
weight,
output,
{input_requires_grad, weight_requires_grad});
return {
std::get<0>(grad_inputs), Tensor(), std::get<1>(grad_inputs), Tensor()};
}
};

Tensor rms_norm_impl(
const Tensor& input,
at::IntArrayRef normalized_shape,
const Tensor& weight,
double epsilon) {
auto output = IPEXRmsNormOp::apply(input, normalized_shape, weight, epsilon);
return output[0];
}
} // namespace AtenIpexTypeXPU
} // namespace at

namespace {
IPEX_LIBRARY_FRAGMENT() {
IPEX_OP_REGISTER_DISPATCH(
"rms_norm_impl",
at::AtenIpexTypeXPU::rms_norm_impl,
c10::DispatchKey::AutogradXPU);
IPEX_OP_REGISTER("rms_norm.xpu", at::AtenIpexTypeXPU::rms_norm_fw);
}
} // namespace
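
For illustration only (not part of this commit), the newly registered autograd op can be exercised end to end roughly as follows. This is a sketch that assumes an available XPU device and an IPEX build containing this change; the shapes are arbitrary. It is the same path the new test exercises through torch.xpu.IpexRmsNorm.

    import torch
    import intel_extension_for_pytorch  # noqa: F401  (registers torch.ops.torch_ipex.*)

    # Hypothetical shapes; requires an available XPU device.
    x = torch.randn(4, 1024, 4096, device="xpu", dtype=torch.bfloat16, requires_grad=True)
    w = torch.ones(4096, device="xpu", dtype=torch.bfloat16, requires_grad=True)

    # Forward dispatches to IPEXRmsNormOp via DispatchKey::AutogradXPU.
    y = torch.ops.torch_ipex.rms_norm_impl(x, [4096], w, 1e-5)

    # Backward now routes through rms_norm_bw and fills both gradients.
    y.backward(torch.ones_like(y))
    print(x.grad.shape, w.grad.shape)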
5 changes: 5 additions & 0 deletions intel_extension_for_pytorch/xpu/intrinsic/__init__.py
@@ -25,6 +25,7 @@
"copy_blocks",
"swap_blocks",
"IpexPaged_attention",
"IpexRmsNorm",
]


@@ -164,6 +165,10 @@ def IpexSDP_dropout(
)


def IpexRmsNorm(input, normalized_shape, weight, epsilon) -> Tensor:
return torch.ops.torch_ipex.rms_norm_impl(input, normalized_shape, weight, epsilon)


def varlen_fwd(
query, # [total_q, num_head, head_size]
key, # [total_k, num_head_k, head_size]
45 changes: 45 additions & 0 deletions tests/gpu/examples/test_rms_norm.py
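
The tests below compare against a Python reference module RMSNormRef, which is defined earlier in the test file and not shown in this diff. A minimal sketch consistent with how it is used here (hypothetical, for orientation only):

    import torch

    class RMSNormRef(torch.nn.Module):
        # Plain-PyTorch RMSNorm reference: y = x * rsqrt(mean(x^2) + eps) * weight
        def __init__(self, hidden_size, eps=1e-5):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x):
            variance = x.float().pow(2).mean(-1, keepdim=True)
            x_hat = (x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype)
            return self.weight * x_hat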
@@ -40,11 +40,56 @@ def test_rms_norm_fw_xpu(dtype):
w = model.weight.xpu()
output = torch.ops.torch_ipex.rms_norm(input_case, [hsz], w, 1e-5)
output1 = ipex.llm.modules.RMSNorm.apply(input_case, w, 1e-5)
output2 = torch.xpu.IpexRmsNorm(input_case, [hsz], w, 1e-5)
# diff = (output.cpu() - output_ref).abs().max().item()
# print('diff', diff)
# assert diff < 1e-2
self.assertEqual(output[0].cpu(), output_ref, atol=1e-2, rtol=1e-2)
self.assertEqual(output1.cpu(), output_ref, atol=1e-2, rtol=1e-2)
self.assertEqual(output2.cpu(), output_ref, atol=1e-2, rtol=1e-2)

test_rms_norm_fw_xpu(torch.float)
test_rms_norm_fw_xpu(torch.bfloat16)

def test_rms_norm_bw(self):
def test_rms_norm_fwd_bwd(dtype):
print("test_rms_norm_fw_bw", dtype)
torch.manual_seed(13)
modelb = RMSNormRef(64)
model0 = RMSNormRef(768)
model1 = RMSNormRef(2048)
model2 = RMSNormRef(4096)
model3 = RMSNormRef(16384)
model4 = RMSNormRef(16384 * 4 + 123)
hszs = [64, 768, 2048, 4096, 16384, 16384 * 4 + 123]
ls = [modelb, model0, model1, model2, model3, model4]
for i, model in enumerate(ls):
model = model.to(dtype)
hsz = hszs[i]
input_case = torch.rand(4, 1024, hsz).to(dtype)
input_case.requires_grad_(True)
grad = torch.rand(4, 1024, hsz).to(dtype)
output_ref = model(input_case)
output_ref.backward(grad)
grad_wei = model.weight.grad.clone()
input_grad_cpu = input_case.grad.clone()
w = model.weight.clone()

input_case_xpu = input_case.clone().xpu()
input_case_xpu.retain_grad()
input_case_xpu.requires_grad_(True)
grad_xpu = grad.xpu()
w = w.xpu()
w.retain_grad()
w.requires_grad_(True)
output1 = torch.xpu.IpexRmsNorm(input_case_xpu, [hsz], w, 1e-5)
output1.backward(grad_xpu)
grad_wei_xpu = w.grad

self.assertEqual(grad_wei_xpu.cpu(), grad_wei, atol=10e-2, rtol=10e-2)
self.assertEqual(
input_case_xpu.grad.cpu(), input_grad_cpu, atol=10e-2, rtol=10e-2
)

test_rms_norm_fwd_bwd(torch.bfloat16)
test_rms_norm_fwd_bwd(torch.float)
