From 862286518b50bf53c76fd8e71ca8e92124a2f688 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Wed, 27 Sep 2023 14:22:08 +0800
Subject: [PATCH] add weight_only_linear_grad kernel to support
 weight_only_linear backward (#57685)

* add weight_only_linear_grad to support weight_only_linear backward

* update

---------

Co-authored-by: xiaoxiaohehe001
---
 paddle/fluid/eager/eager_amp_auto_cast.h      |   2 +-
 paddle/phi/api/yaml/backward.yaml             |  12 +
 paddle/phi/api/yaml/ops.yaml                  |   1 +
 paddle/phi/infermeta/backward.cc              |  11 +
 paddle/phi/infermeta/backward.h               |   8 +
 .../kernels/funcs/weight_dequant_functor.h    | 283 ++++++++++++++++++
 .../gpu/weight_only_linear_grad_kernel.cu     |  74 +++++
 .../kernels/weight_only_linear_grad_kernel.h  |  28 ++
 8 files changed, 418 insertions(+), 1 deletion(-)
 create mode 100644 paddle/phi/kernels/funcs/weight_dequant_functor.h
 create mode 100644 paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/weight_only_linear_grad_kernel.h

diff --git a/paddle/fluid/eager/eager_amp_auto_cast.h b/paddle/fluid/eager/eager_amp_auto_cast.h
index a612a84d2ae1c..56549d427accf 100644
--- a/paddle/fluid/eager/eager_amp_auto_cast.h
+++ b/paddle/fluid/eager/eager_amp_auto_cast.h
@@ -90,7 +90,7 @@ inline paddle::Tensor EagerAmpAutoCast(const std::string& input_name,
           << " input(" << egr::EagerUtils::TensorStr(input) << " to dst_dtype("
           << phi::DataTypeToString(dst_dtype) << ").";
   if ((op_name == "batch_norm" || op_name == "layer_norm" ||
-       op_name == "sync_batch_norm") &&
+       op_name == "sync_batch_norm" || op_name == "weight_only_linear") &&
       input_name != "x") {
     return input;
   }
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index b9695e2909b11..7be497318443a 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -2568,6 +2568,18 @@
     func : warprnnt_grad
   no_need_buffer : input
 
+- backward_op : weight_only_linear_grad
+  forward : weight_only_linear(Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype) -> Tensor(out)
+  args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, Tensor out_grad, str weight_dtype)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : WeightOnlyLinearGradInferMeta
+  kernel :
+    func : weight_only_linear_grad
+    data_type : out_grad
+  optional: bias
+  no_need_buffer: x
+
 - backward_op : where_grad
   forward : where (Tensor condition, Tensor x, Tensor y) -> Tensor(out)
   args : (Tensor condition, Tensor x, Tensor y, Tensor out_grad)
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index fdada46699d26..fff70e820e575 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -2795,6 +2795,7 @@
     func : weight_only_linear
     data_type : x
   optional: bias
+  backward: weight_only_linear_grad
 
 - op : weight_quantize
   args : (Tensor x, str algo = "weight_only_int8")
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index e66e55ca97153..4c5e130aab7a0 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -1156,6 +1156,17 @@ void UnStackGradInferMeta(const std::vector<const MetaTensor*>& out_grad,
   x_grad->set_dtype(out_grad[0]->dtype());
 }
 
+void WeightOnlyLinearGradInferMeta(const MetaTensor& x,
+                                   const MetaTensor& weight,
+                                   const MetaTensor& bias,
+                                   const MetaTensor& weight_scale,
+                                   const MetaTensor& out_grad,
+                                   const std::string& weight_dtype,
+                                   MetaTensor* x_grad) {
+  x_grad->set_dims(x.dims());
+  x_grad->set_dtype(x.dtype());
+}
+
 void YoloLossGradInferMeta(const MetaTensor& x,
                            const MetaTensor& gt_box,
                            const MetaTensor& gt_label,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index a00bc2cde450f..13dd392344f97 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -445,6 +445,14 @@ void UnStackGradInferMeta(const std::vector<const MetaTensor*>& out_grad,
                           int axis,
                           MetaTensor* x_grad);
 
+void WeightOnlyLinearGradInferMeta(const MetaTensor& x,
+                                   const MetaTensor& weight,
+                                   const MetaTensor& bias,
+                                   const MetaTensor& weight_scale,
+                                   const MetaTensor& out_grad,
+                                   const std::string& weight_dtype,
+                                   MetaTensor* x_grad);
+
 void YoloLossGradInferMeta(const MetaTensor& x,
                            const MetaTensor& gt_box,
                            const MetaTensor& gt_label,
diff --git a/paddle/phi/kernels/funcs/weight_dequant_functor.h b/paddle/phi/kernels/funcs/weight_dequant_functor.h
new file mode 100644
index 0000000000000..dd1631ca722ee
--- /dev/null
+++ b/paddle/phi/kernels/funcs/weight_dequant_functor.h
@@ -0,0 +1,283 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
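+
+// Reference dequantization kernels for weight-only quantized GEMM weights:
+// they undo the row interleaving and permutation applied by the cutlass
+// weight preprocessing and rescale the stored int8 / int4 values back to
+// half / bfloat16 rows, so the result can be fed to an ordinary dense matmul.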
+
+#pragma once
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/datatype_traits.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/aligned_vector.h"
+#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
+#include "paddle/phi/kernels/matmul_kernel.h"
+
+namespace phi {
+
+template <typename T, int WeightBits>
+struct FastWeightOnlyHalfConverter;
+
+template <>
+struct FastWeightOnlyHalfConverter<half, 8> {
+  using Converter =
+      cutlass::FastInterleavedAndBiasedNumericArrayConverter<cutlass::half_t,
+                                                             uint8_t,
+                                                             4>;
+  static constexpr int kHalfLength = 4;
+  static constexpr int kWeightOnlyLength = 4;
+
+  __device__ static inline void convert(half halves[kHalfLength],
+                                        uint8_t chars[kWeightOnlyLength],
+                                        float scale) {
+    *reinterpret_cast<typename Converter::result_type*>(halves) =
+        Converter::convert(
+            *reinterpret_cast<typename Converter::source_type*>(chars));
+#pragma unroll
+    for (int i = 0; i < kHalfLength; ++i) {
+      float dequant_value = __half2float(halves[i]) * scale;
+      halves[i] = __float2half_rn(dequant_value);
+    }
+  }
+};
+
+template <>
+struct FastWeightOnlyHalfConverter<half, 4> {
+  using Converter =
+      cutlass::FastInterleavedAndBiasedNumericArrayConverter<cutlass::half_t,
+                                                             cutlass::uint4b_t,
+                                                             8>;
+  static constexpr int kHalfLength = 8;
+  static constexpr int kWeightOnlyLength = 4;
+
+  __device__ static inline void convert(half halves[kHalfLength],
+                                        uint8_t chars[kWeightOnlyLength],
+                                        float scale) {
+    *reinterpret_cast<typename Converter::result_type*>(halves) =
+        Converter::convert(
+            *reinterpret_cast<typename Converter::source_type*>(chars));
+#pragma unroll
+    for (int i = 0; i < kHalfLength; ++i) {
+      float dequant_value = __half2float(halves[i]) * scale;
+      halves[i] = __float2half_rn(dequant_value);
+    }
+  }
+};
+
+#if defined(PADDLE_CUDA_BF16)
+template <>
+struct FastWeightOnlyHalfConverter<__nv_bfloat16, 8> {
+  using Converter = cutlass::FastInterleavedAndBiasedNumericArrayConverter<
+      cutlass::bfloat16_t,
+      uint8_t,
+      4>;
+  static constexpr int kHalfLength = 4;
+  static constexpr int kWeightOnlyLength = 4;
+
+  __device__ static inline void convert(__nv_bfloat16 halves[kHalfLength],
+                                        uint8_t chars[kWeightOnlyLength],
+                                        float scale) {
+    *reinterpret_cast<typename Converter::result_type*>(halves) =
+        Converter::convert(
+            *reinterpret_cast<typename Converter::source_type*>(chars));
+#pragma unroll
+    for (int i = 0; i < kHalfLength; ++i) {
+      float dequant_value = __bfloat162float(halves[i]) * scale;
+      halves[i] = __float2bfloat16_rn(dequant_value);
+    }
+  }
+};
+
+template <>
+struct FastWeightOnlyHalfConverter<__nv_bfloat16, 4> {
+  using Converter = cutlass::FastInterleavedAndBiasedNumericArrayConverter<
+      cutlass::bfloat16_t,
+      cutlass::uint4b_t,
+      8>;
+  static constexpr int kHalfLength = 8;
+  static constexpr int kWeightOnlyLength = 4;
+
+  __device__ static inline void convert(__nv_bfloat16 halves[kHalfLength],
+                                        uint8_t chars[kWeightOnlyLength],
+                                        float scale) {
+    *reinterpret_cast<typename Converter::result_type*>(halves) =
+        Converter::convert(
+            *reinterpret_cast<typename Converter::source_type*>(chars));
+#pragma unroll
+    for (int i = 0; i < kHalfLength; ++i) {
+      float dequant_value = __bfloat162float(halves[i]) * scale;
+      halves[i] = __float2bfloat16_rn(dequant_value);
+    }
+  }
+};
+#endif
+
+template <typename T>
+__global__ void int8_weight_only_dequant(const uint8_t* weight,
+                                         const float* scale_list,
+                                         T* output,
+                                         const int n,
+                                         const int k) {
+  using Converter = FastWeightOnlyHalfConverter<T, 8>;
+  AlignedVector<uint8_t, 16> vec_weight;
+  T vec_weight_f16[16];
+  AlignedVector<T, 16> vec_out;
+
+  int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32;
+  int tile_id = blockIdx.x * blockDim.x / 32 + warp_id;
+  // Every two rows of the original weights are interleaved into a single row
+  // with a stride of 64,
+  // so if each thread processes 16 elements (for int8, we can use ldg.128 to
+  // load the weights), then every group of four adjacent threads alternately
+  // processes two different rows. For example, take the 128 consecutive int8
+  // elements [128*i, 128*(i+1)-1] of row N in the interleaved layout: the
+  // first 64 come from [64*i, 64*(i+1)-1] of row 2N before interleaving, and
+  // the last 64 come from [64*i, 64*(i+1)-1] of row 2N+1 before interleaving.
+  // So if each thread loads 16 int8 elements, the first four and the last
+  // four of every 8 consecutive threads read from row 2N and row 2N+1 of the
+  // pre-interleave weights, respectively.
+  int row_id = tile_id * 2 + ((lane_id % 8) > 3 ? 1 : 0);
+  weight += tile_id * k * 2;
+  output += row_id * k;
+  float scale = scale_list[row_id];
+#pragma unroll
+  for (int i = lane_id * 16; i < k * 2; i += 16 * 32) {
+    Load<uint8_t, 16>(&weight[i], &vec_weight);
+#pragma unroll
+    for (int p = 0; p < 16; p += Converter::kHalfLength) {
+      // The rearrangement here counteracts the effect of
+      // cutlass::add_bias_and_interleave_int8s_inplace.
+      // Input int8 data layout
+      //   [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
+      //
+      // Converted fp16 data layout
+      //   [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
+      // vec_weight_f16[p] =
+      //     static_cast<T>(static_cast<float>(vec_weight[p]) * scale);
+      // fast_cvt_4_packed_signed_i8s_to_2_half2s()
+      Converter::convert(vec_weight_f16 + p, &vec_weight[p], scale);
+    }
+#pragma unroll
+    for (int p = 0; p < 16; ++p) {
+      // The index remapping here counteracts the effect of
+      // cutlass::permute_B_rows_for_mixed_gemm:
+      //   input  0 1 2 3 4 5  6  7 8 9 10 11 12 13 14 15
+      //   weight 0 1 8 9 2 3 10 11 4 5 12 13  6  7 14 15
+      vec_out[p] = vec_weight_f16[4 * ((p % 8) / 2) + p % 2 + 2 * (p / 8)];
+    }
+    Store<T, 16>(vec_out, &output[i / 128 * 64 + (i % 64)]);
+  }
+}
+
+template <typename T>
+__global__ void int4_weight_only_dequant(const uint8_t* weight,
+                                         const float* scale_list,
+                                         T* output,
+                                         const int n,
+                                         const int k) {
+  using Converter = FastWeightOnlyHalfConverter<T, 4>;
+
+  AlignedVector<uint8_t, 16> vec_weight;
+  T vec_weight_f16[32];
+  AlignedVector<T, 32> vec_out;
+
+  int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32;
+  int tile_id = blockIdx.x * blockDim.x / 32 + warp_id;
+  // Every two rows of the original weights are interleaved into a single row
+  // with a stride of 64, so if each thread processes 16 elements (for int8,
+  // we can use ldg.128 to load the weights), then every group of four
+  // adjacent threads alternately processes two different rows. For example,
+  // take the 128 consecutive int8 elements [128*i, 128*(i+1)-1] of row N in
+  // the interleaved layout: the first 64 come from [64*i, 64*(i+1)-1] of row
+  // 2N before interleaving, and the last 64 come from [64*i, 64*(i+1)-1] of
+  // row 2N+1 before interleaving. So if each thread loads 16 int8 elements,
+  // the first four and the last four of every 8 consecutive threads read from
+  // row 2N and row 2N+1 of the pre-interleave weights, respectively.
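+  // In this int4 path, each tile covers four pre-interleave rows: within
+  // every group of eight lanes, lanes {0,1}, {2,3}, {4,5} and {6,7} handle
+  // rows 4*tile_id+0..3 respectively, and each lane consumes 32 int4 values
+  // (16 bytes) per iteration.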
+  int row_id = tile_id * 4 + ((lane_id % 8) / 2);
+  weight += tile_id * k / 2 * 4;
+  output += row_id * k;
+  float scale = scale_list[row_id];
+#pragma unroll
+  for (int i = lane_id * 32; i < k * 4; i += 32 * 32) {
+    Load<uint8_t, 16>(&weight[i / 2], &vec_weight);
+#pragma unroll
+    for (int p = 0; p < 32; p += Converter::kHalfLength) {
+      // The rearrangement here counteracts the effect of
+      // cutlass::add_bias_and_interleave_int4s_inplace.
+      // Input int4 data layout
+      //   [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0]
+      //   (each elt occupies 4 bits)
+      //
+      // Converted fp16 data layout
+      //   [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0]
+      //   (each elt occupies 16 bits)
+      // vec_weight_f16[p] =
+      //     static_cast<T>(static_cast<float>(vec_weight[p]) * scale);
+      Converter::convert(vec_weight_f16 + p, &vec_weight[p / 2], scale);
+    }
+#pragma unroll
+    for (int p = 0; p < 32; ++p) {
+      // The index remapping here counteracts the effect of
+      // cutlass::permute_B_rows_for_mixed_gemm:
+      //   input  0 1 2 3  4  5  6  7 8 9 10 11 12 13 14 15 ... 31
+      //   weight 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27
+      //          4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
+      vec_out[p] = vec_weight_f16[8 * ((p % 8) / 2) + p % 2 + 2 * (p / 8)];
+    }
+    Store<T, 32>(vec_out, &output[i / 256 * 64 + (i % 64)]);
+  }
+}
+
+template <typename T, typename Context>
+void WeightDequantize(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& scale,
+                      const std::string& algo,
+                      const bool transpose,
+                      DenseTensor* out) {
+  using DataType = typename PDDataTypeTraits<T>::DataType;
+
+  int n = scale.dims()[0];
+  int k = x.dims()[1];
+  dim3 block(512);
+  dim3 grid(n / 32);
+  auto stream = dev_ctx.stream();
+
+  if (algo == "weight_only_int8") {
+    int8_weight_only_dequant<<<grid, block, 0, stream>>>(
+        reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
+        scale.data<float>(),
+        reinterpret_cast<DataType*>(out->data<T>()),
+        n,
+        k);
+  } else if (algo == "weight_only_int4") {
+    grid.x /= 2;
+    int4_weight_only_dequant<<<grid, block, 0, stream>>>(
+        reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
+        scale.data<float>(),
+        reinterpret_cast<DataType*>(out->data<T>()),
+        n,
+        k);
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu
new file mode 100644
index 0000000000000..f327ccef1a1aa
--- /dev/null
+++ b/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
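+
+// Backward of weight_only_linear. Only x_grad is produced here: the quantized
+// [n, k] weight is dequantized back to T with WeightDequantize and then
+// x_grad = matmul(out_grad, dequant(weight)).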
+ +#include "paddle/phi/kernels/weight_only_linear_grad_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/datatype_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/matmul_kernel.h" + +#if defined(PADDLE_WITH_CUTLASS) +#include "paddle/phi/kernels/funcs/weight_dequant_functor.h" +#endif + +namespace phi { + +template +void WeightOnlyLinearGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const paddle::optional& bias, + const DenseTensor& weight_scale, + const DenseTensor& out_grad, + const std::string& weight_dtype, + DenseTensor* x_grad) { +#if defined(PADDLE_WITH_CUTLASS) + int n = weight_scale.dims()[0]; + int k = weight.dims()[1]; + dev_ctx.template Alloc(x_grad); + DenseTensor weight_dequantized; + weight_dequantized.Resize({{n, k}}); + dev_ctx.template Alloc(&weight_dequantized); + std::string algo = + weight_dtype == "int8" ? "weight_only_int8" : "weight_only_int4"; + WeightDequantize( + dev_ctx, weight, weight_scale, algo, true, &weight_dequantized); + MatmulKernel( + dev_ctx, out_grad, weight_dequantized, false, false, x_grad); +#endif +} +} // namespace phi + +PD_REGISTER_KERNEL(weight_only_linear_grad, + GPU, + ALL_LAYOUT, + phi::WeightOnlyLinearGradKernel, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/weight_only_linear_grad_kernel.h b/paddle/phi/kernels/weight_only_linear_grad_kernel.h new file mode 100644 index 0000000000000..6cf44ef6d4688 --- /dev/null +++ b/paddle/phi/kernels/weight_only_linear_grad_kernel.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void WeightOnlyLinearGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const paddle::optional& bias, + const DenseTensor& weight_scale, + const DenseTensor& out_grad, + const std::string& weight_dtype, + DenseTensor* x_grad); + +} // namespace phi