Add matmul_int8 op #55228

Merged: 9 commits, Jul 13, 2023
68 changes: 34 additions & 34 deletions paddle/fluid/operators/fused/attn_gemm_int8.h
@@ -57,29 +57,29 @@ class AttnMatmulINT8 {
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
quantize_kernel_launcher<T>(input->data<T>(),
input_tmp->data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx_.stream());
LaunchQuantKernel<T>(input->data<T>(),
input_tmp->data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx_.stream());

helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream());

dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>());
LaunchDequantKernel<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>());

if (compute_bias_) {
// bias_out = output + bias
@@ -126,14 +126,14 @@ class AttnMatmulINT8 {
output_tmp->data<int32_t>(),
dev_ctx_.stream());

dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>());
LaunchDequantKernel<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>());

if (compute_bias_) {
// bias_out = output + bias
@@ -162,15 +162,15 @@ class AttnMatmulINT8 {
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
quantize_kernel_launcher<T>(input->data<T>(),
input_tmp->data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx_.stream());
LaunchQuantKernel<T>(input->data<T>(),
input_tmp->data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx_.stream());

helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
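To make the flow above easier to follow, here is a minimal CPU reference of the three stages that `AttnMatmulINT8` drives on the GPU: quantize the activations to int8, run an int8×int8 GEMM with int32 accumulation, then dequantize back to floating point. This is a sketch, not Paddle code: the helper names (`QuantizeRef`, `GemmS8S8S32Ref`, `DequantizeRef`), the row-major layout, and the exact scale arithmetic are assumptions; the authoritative formulas are `quant_helper` and `DequantKernel` in `quant_dequant_kernel.h` below.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Round-to-nearest quantization with clipping to [min_bound, max_bound].
// The scale convention here is an assumption; see quant_helper for the
// formula the CUDA kernels actually use.
static int8_t QuantizeRef(float x, float scale, float max_bound = 127.f,
                          float min_bound = -127.f) {
  float v = std::round(x * scale);
  return static_cast<int8_t>(std::min(max_bound, std::max(min_bound, v)));
}

// Naive int8 x int8 -> int32 GEMM: C[m x n] = A[m x k] * B[k x n],
// accumulating in int32 as the cuBLASLt int8 path does.
static void GemmS8S8S32Ref(const std::vector<int8_t>& a,
                           const std::vector<int8_t>& b,
                           std::vector<int32_t>* c, int m, int k, int n) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;
      for (int p = 0; p < k; ++p) {
        acc += static_cast<int32_t>(a[i * k + p]) *
               static_cast<int32_t>(b[p * n + j]);
      }
      (*c)[i * n + j] = acc;
    }
  }
}

// Dequantize back to float with the input scale and a per-column output
// scale, mirroring the dequant_out_scale tensor (scale math assumed).
static float DequantizeRef(int32_t v, float in_scale, float out_scale) {
  return static_cast<float>(v) * in_scale * out_scale;
}

int main() {
  const int m = 2, k = 4, n = 3;
  const std::vector<float> x = {0.1f, -0.2f, 0.3f,  0.4f,
                                -0.5f, 0.6f, -0.7f, 0.8f};
  const std::vector<int8_t> w(k * n, 1);  // stand-in int8 weights

  std::vector<int8_t> x_q(m * k);
  for (int i = 0; i < m * k; ++i) x_q[i] = QuantizeRef(x[i], /*scale=*/127.f);

  std::vector<int32_t> acc(m * n);
  GemmS8S8S32Ref(x_q, w, &acc, m, k, n);

  std::vector<float> out(m * n);
  for (int i = 0; i < m * n; ++i) {
    out[i] = DequantizeRef(acc[i], /*in_scale=*/1.f / 127.f, /*out_scale=*/1.f);
  }
  return 0;
}
```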
80 changes: 40 additions & 40 deletions paddle/fluid/operators/fused/quant_dequant_kernel.h
@@ -47,14 +47,14 @@ __forceinline__ __device__ int8_t quant_helper(const T input,
}

template <typename T>
__global__ void quantize_kernel(const T* input,
char4* output,
const float scale,
const int m,
const int n,
const int round_type,
const float max_bound,
const float min_bound) {
__global__ void QuantKernel(const T* input,
char4* output,
const float scale,
const int m,
const int n,
const int round_type,
const float max_bound,
const float min_bound) {
int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
int m_id = blockIdx.y * blockDim.y + threadIdx.y;

@@ -74,36 +74,36 @@ __global__ void quantize_kernel(const T* input,
}

template <typename T>
void quantize_kernel_launcher(const T* input,
int8_t* output,
const float scale,
const int m,
const int n,
const int round_type,
const float max_bound,
const float min_bound,
gpuStream_t stream) {
void LaunchQuantKernel(const T* input,
int8_t* output,
const float scale,
const int m,
const int n,
const int round_type,
const float max_bound,
const float min_bound,
gpuStream_t stream) {
// TODO(minghaoBD): optimize the kernel launch times when m==1 or n==1
dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32);
dim3 block(32, 32);

quantize_kernel<<<grid, block, 0, stream>>>(input,
(char4*)output, // NOLINT
scale,
m,
n,
round_type,
max_bound,
min_bound);
QuantKernel<<<grid, block, 0, stream>>>(input,
(char4*)output, // NOLINT
scale,
m,
n,
round_type,
max_bound,
min_bound);
}

template <typename T, int VecSize>
__global__ void dequantize_kernel(T* output,
const int32_t* input,
const int m, // batch size
const int n, // hidden
const float quant_in_scale,
const float* dequant_out_scale_data) {
__global__ void DequantKernel(T* output,
const int32_t* input,
const int m, // batch size
const int n, // hidden
const float quant_in_scale,
const float* dequant_out_scale_data) {
int numel = m * n;
int stride = blockDim.x * gridDim.x * VecSize;
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
@@ -128,15 +128,15 @@ __global__ void dequantize_kernel(T* output,
}

template <typename T>
void dequantize_kernel_launcher(const int32_t* input,
T* output,
const int m, // m
const int n, // n
gpuStream_t stream,
GpuLaunchConfig* gpu_config,
const float quant_in_scale,
const float* dequant_out_scale_data) {
dequantize_kernel<T, DequantKernelVecSize>
void LaunchDequantKernel(const int32_t* input,
T* output,
const int m, // m
const int n, // n
gpuStream_t stream,
GpuLaunchConfig* gpu_config,
const float quant_in_scale,
const float* dequant_out_scale_data) {
DequantKernel<T, DequantKernelVecSize>
<<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
output, input, m, n, quant_in_scale, dequant_out_scale_data);
}
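A note on the launch geometry above: `QuantKernel` stores one `char4` per thread, so the grid is meant to cover n/4 columns and m rows with 32×32 thread blocks. Since `>>` binds more loosely than `+` in C++, the intended grouping is written out explicitly in the sketch below; the helper name `QuantGridDims`, and the assumption that n is a multiple of 4, are illustrative only.

```cpp
#include <cstdio>

// Grid for a 32x32 block where each thread quantizes one char4 (four int8
// values): ceil((n / 4) / 32) blocks across columns, ceil(m / 32) down rows.
static void QuantGridDims(int m, int n, int* grid_x, int* grid_y) {
  *grid_x = ((n >> 2) + 31) / 32;  // column blocks, four columns per thread
  *grid_y = (m + 31) / 32;         // row blocks
}

int main() {
  int gx = 0, gy = 0;
  QuantGridDims(/*m=*/16, /*n=*/4096, &gx, &gy);
  // 4096 / 4 = 1024 column threads -> 32 blocks in x; 16 rows -> 1 block in y.
  std::printf("grid = (%d, %d)\n", gx, gy);  // grid = (32, 1)
  return 0;
}
```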
8 changes: 8 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -513,6 +513,14 @@
func : matmul
backward : matmul_grad

- op : matmul_int8
args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)
output : Tensor
infer_meta :
func : MatmulInt8InferMeta
kernel :
func : matmul_int8

- op : matrix_rank
args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
output : Tensor(out)
70 changes: 70 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -2096,6 +2096,76 @@ void MatmulInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}

void MatmulInt8InferMeta(const MetaTensor& x,
const MetaTensor& y,
bool trans_x,
bool trans_y,
MetaTensor* out) {
std::vector<int64_t> dims_x = phi::vectorize(x.dims());
std::vector<int64_t> dims_y = phi::vectorize(y.dims());
auto ndims_x = dims_x.size();
auto ndims_y = dims_y.size();
PADDLE_ENFORCE_GT(ndims_x,
0UL,
phi::errors::InvalidArgument(
"The Input(x) dims size must be greater than 0,"
" but reviced dims size is 0. "));
PADDLE_ENFORCE_GT(ndims_y,
0UL,
phi::errors::InvalidArgument(
"The Input(y) dims size must be greater than 0,"
" but reviced dims size is 0. "));

bool x_broadcasted = false, y_broadcasted = false;
if (ndims_x == 1) {
dims_x.insert(dims_x.begin(), 1);
ndims_x = 2;
x_broadcasted = true;
}

if (ndims_y == 1) {
dims_y.push_back(1);
ndims_y = 2;
y_broadcasted = true;
}

size_t M, N;
if (trans_x) {
M = dims_x[ndims_x - 1];
} else {
M = dims_x[ndims_x - 2];
}
if (trans_y) {
N = dims_y[ndims_y - 2];
} else {
N = dims_y[ndims_y - 1];
}

std::vector<int64_t> new_dims;
if (ndims_x > ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2);
} else if (ndims_x < ndims_y) {
new_dims.assign(dims_y.begin(), dims_y.end() - 2);
} else {
new_dims.reserve(ndims_x);
for (size_t i = 0; i < ndims_x - 2; ++i) {
new_dims.push_back(std::max(dims_x[i], dims_y[i]));
}
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}

auto ddim_out = phi::make_ddim(new_dims);

out->set_dims(ddim_out);
out->set_dtype(phi::DataType::INT32);
out->set_layout(x.layout());
}

void MatmulWithFlattenInferMeta(const MetaTensor& x,
const MetaTensor& y,
int x_num_col_dims,
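`MatmulInt8InferMeta` applies the same shape rules as the float `MatmulInferMeta` — promote 1-D operands to 2-D, take M and N from the trailing dimensions (honoring the transpose flags), broadcast the leading batch dimensions — and fixes the output dtype to INT32. Below is a standalone sketch of that shape rule for checking a case by hand; the function name and `std::vector` interface are illustrative, not part of the PR.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Mirrors the shape logic above: 1-D operands are promoted to 2-D, the
// trailing two dims supply M/K and K/N (after optional transposes), and the
// leading batch dims are broadcast elementwise.
std::vector<int64_t> MatmulInt8OutShape(std::vector<int64_t> x,
                                        std::vector<int64_t> y,
                                        bool trans_x, bool trans_y) {
  bool x_bc = false, y_bc = false;
  if (x.size() == 1) { x.insert(x.begin(), 1); x_bc = true; }
  if (y.size() == 1) { y.push_back(1); y_bc = true; }

  int64_t M = trans_x ? x.back() : x[x.size() - 2];
  int64_t N = trans_y ? y[y.size() - 2] : y.back();

  std::vector<int64_t> out;
  if (x.size() > y.size()) {
    out.assign(x.begin(), x.end() - 2);
  } else if (x.size() < y.size()) {
    out.assign(y.begin(), y.end() - 2);
  } else {
    for (size_t i = 0; i + 2 < x.size(); ++i)
      out.push_back(std::max(x[i], y[i]));
  }
  if (!x_bc) out.push_back(M);
  if (!y_bc) out.push_back(N);
  return out;  // the real op's output dtype is INT32
}

int main() {
  // x: int8[2, 3, 4], y: int8[4, 5]  ->  out: int32[2, 3, 5]
  auto s = MatmulInt8OutShape({2, 3, 4}, {4, 5}, false, false);
  return s == std::vector<int64_t>{2, 3, 5} ? 0 : 1;
}
```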
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -333,6 +333,12 @@ void MatmulInferMeta(const MetaTensor& x,
bool trans_y,
MetaTensor* out);

void MatmulInt8InferMeta(const MetaTensor& x,
const MetaTensor& y,
bool trans_x,
bool trans_y,
MetaTensor* out);

void MatmulWithFlattenInferMeta(const MetaTensor& x,
const MetaTensor& y,
int x_num_col_dims,
29 changes: 10 additions & 19 deletions paddle/phi/kernels/funcs/cublaslt.h
@@ -39,25 +39,16 @@ const std::map<std::tuple<int, int, int>, CublasLtAlgoParam> AlgoParamCache{};

class CublasLtHelper {
public:
CublasLtHelper(int m, int k, int n)
: alpha_(1), beta_(0), m_(m), k_(k), n_(n) {
CublasLtHelper(int m, int k, int n, cublasLtHandle_t handle)
: handle_(handle), alpha_(1), beta_(0), m_(m), k_(k), n_(n) {
cublasStatus_t status;
// handle and matmul desc
status = dyl::cublasLtCreate(&handle_);
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t cudaComputeType = CUDA_R_32I;
#else
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif

PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));

// matmul desc
#if CUBLAS_VER_MAJOR < 11
status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType);
#else
@@ -179,7 +170,7 @@ class CublasLtHelper {
}
~CublasLtHelper() {}

void GEMM(int8_t* A_dev,
void GEMM(const int8_t* A_dev,
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream,
@@ -226,14 +217,14 @@

cublasLtMatmulAlgo_t algo_;

int32_t alpha_;
int32_t beta_;
int32_t alpha_ = 1;
int32_t beta_ = 0;

int m_;
int k_;
int n_;
int m_ = 0;
int k_ = 0;
int n_ = 0;

size_t workspace_size_;
size_t workspace_size_ = 0;
};

} // namespace phi
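With the constructor change above, `CublasLtHelper` no longer calls `cublasLtCreate` itself; the caller creates one `cublasLtHandle_t` and shares it across every helper, so each helper only keeps per-shape layout and algorithm state. Below is a hedged sketch of that ownership pattern, using the plain cuBLASLt API rather than Paddle's `dyl::` wrappers; the `BuildHelpers` function and the helper shapes are made up for illustration.

```cpp
#include <cublasLt.h>

#include <memory>
#include <vector>

#include "paddle/phi/kernels/funcs/cublaslt.h"  // phi::CublasLtHelper

// The caller owns the handle; every CublasLtHelper borrows it.
void BuildHelpers(int m, int k, int n, cublasLtHandle_t handle,
                  std::vector<std::unique_ptr<phi::CublasLtHelper>>* helpers) {
  helpers->push_back(std::make_unique<phi::CublasLtHelper>(m, k, n, handle));
  helpers->push_back(std::make_unique<phi::CublasLtHelper>(m, n, k, handle));
}

// Usage: create once, share, destroy only after the helpers are gone.
//   cublasLtHandle_t handle;
//   cublasLtCreate(&handle);
//   std::vector<std::unique_ptr<phi::CublasLtHelper>> helpers;
//   BuildHelpers(m, k, n, handle, &helpers);
//   ... run GEMMs ...
//   helpers.clear();
//   cublasLtDestroy(handle);
```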