[PIR] register fused_attention in pir (PaddlePaddle#57557)
* register fused_attention in pir

* fix

* fix
kangguangli authored Sep 21, 2023
1 parent abe013a commit 13b67ff
Showing 6 changed files with 342 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -95,6 +95,7 @@
    'c_allreduce_max',
    'c_allgather',
    'seed',
    "fused_attention",
]
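
Editor's note (not part of the commit): ops_api_gen.py drives code generation for the PIR ops Python-C API, and an op only gets a binding once its name is added to one of the lists in this file. A rough, hypothetical sketch of that kind of list-driven generation (the real template and helper names in the generator may differ):

# Hypothetical sketch; the actual template and entry format in ops_api_gen.py may differ.
OPS_API_ENTRY = (
    '{{"{name}", (PyCFunction)(void (*)(void))static_api_{name}, '
    'METH_VARARGS | METH_KEYWORDS, "C++ interface function for {name}."}},'
)

def gen_ops_api_entries(op_names):
    """Emit one method-table entry per registered op name."""
    return "\n".join(OPS_API_ENTRY.format(name=n) for n in op_names)

print(gen_ops_api_entries(["seed", "fused_attention"]))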


9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -170,3 +170,12 @@
  args : (Tensor i, Tensor x)
  output : Tensor[](out)
  backward: write_to_array_grad

- op: fused_attention
  args: (Tensor x, Tensor ln_scale, Tensor ln_bias, Tensor qkv_weight, Tensor qkv_bias, Tensor cache_kv, Tensor src_mask, Tensor out_linear_weight, Tensor out_linear_bias, Tensor ln_scale_2, Tensor ln_bias_2, int num_heads, bool transpose_qkv_wb, bool pre_layer_norm, float epsilon, float attn_dropout_rate, bool is_test, bool attn_dropout_fix_seed, int attn_dropout_seed, str attn_dropout_implementation, float dropout_rate, bool dropout_fix_seed, int dropout_seed, str dropout_implementation, float ln_epsilon, bool add_residual, int ring_id)
  output: Tensor(ln_mean), Tensor(ln_var), Tensor(ln_out), Tensor(qkv_out), Tensor(qkv_bias_out), Tensor(transpose_out_2), Tensor(qk_out), Tensor(qktv_out), Tensor(softmax_out), Tensor(attn_dropout_mask_out), Tensor(attn_dropout_out), Tensor(src_mask_out), Tensor(fmha_out), Tensor(out_linear_out), Tensor(dropout_mask_out), Tensor(ln_mean_2), Tensor(ln_var_2), Tensor(bias_dropout_residual_out), Tensor(cache_kv_out), Tensor(out)
  kernel:
    func: fused_attention
  infer_meta:
    func: FusedAttentionInferMeta
  optional: cache_kv, ln_scale, ln_bias, qkv_bias, src_mask, out_linear_bias, ln_scale_2, ln_bias_2
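
Editor's note (not part of the commit): the args/output lists above must stay aligned, in order and count, with the FusedAttentionInferMeta signature added in paddle/phi/infermeta/multiary.h below: 11 tensor inputs, 16 attributes, and 20 outputs, plus a trailing MetaConfig. A quick plain-Python count as a sanity check:

# Sanity-check sketch: count the entries in the yaml `args:` line above.
args = (
    "Tensor x, Tensor ln_scale, Tensor ln_bias, Tensor qkv_weight, "
    "Tensor qkv_bias, Tensor cache_kv, Tensor src_mask, "
    "Tensor out_linear_weight, Tensor out_linear_bias, "
    "Tensor ln_scale_2, Tensor ln_bias_2, int num_heads, bool transpose_qkv_wb, "
    "bool pre_layer_norm, float epsilon, float attn_dropout_rate, bool is_test, "
    "bool attn_dropout_fix_seed, int attn_dropout_seed, "
    "str attn_dropout_implementation, float dropout_rate, bool dropout_fix_seed, "
    "int dropout_seed, str dropout_implementation, float ln_epsilon, "
    "bool add_residual, int ring_id"
)
params = [p.strip() for p in args.split(",")]
tensor_inputs = [p for p in params if p.startswith("Tensor ")]
attributes = [p for p in params if not p.startswith("Tensor ")]
num_outputs = 20  # Tensor(ln_mean) ... Tensor(out) in the output: line above
print(len(tensor_inputs), len(attributes), num_outputs)  # 11 16 20
# FusedAttentionInferMeta therefore takes 11 + 16 + 20 + 1 (MetaConfig) = 48 parameters.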
35 changes: 35 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -1181,6 +1181,41 @@
data_type : float
support_tensor : true

- op : fused_attention
  inputs:
    x: X
    ln_scale: LnScale
    ln_bias: LnBias
    qkv_weight: QKVW
    qkv_bias: QKVBias
    cache_kv: CacheKV
    src_mask: SrcMask
    out_linear_weight: OutLinearW
    out_linear_bias: OutLinearBias
    ln_scale_2: Ln2Scale
    ln_bias_2: Ln2Bias
  outputs:
    ln_mean: LnMean
    ln_var: LnVariance
    ln_out: LnOut
    qkv_out: QKVOut
    qkv_bias_out: QKVBiasOut
    transpose_out_2: TransposeOut2
    qk_out: QKOut
    qktv_out: QKTVOut
    softmax_out: SoftmaxOut
    attn_dropout_mask_out: AttnDropoutMaskOut
    attn_dropout_out: AttnDropoutOut
    src_mask_out: SrcMaskOut
    fmha_out: FMHAOut
    out_linear_out: OutLinearOut
    dropout_mask_out: DropoutMaskOut
    ln_mean_2: Ln2Mean
    ln_var_2: Ln2Variance
    bias_dropout_residual_out: BiasDropoutResidualOut
    cache_kv_out: CacheKVOut
    out: Y
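
Editor's note (not part of the commit): the entry above maps each new-IR argument name (key) to its legacy fluid operator name (value), e.g. qkv_weight to QKVW and out to Y. Applied programmatically, the mapping is just a dictionary; the helper below is hypothetical, not a Paddle API:

# Illustration only: new-name -> legacy-name pairs from the op_compat entry above,
# plus a tiny hypothetical helper that renames a legacy input dict to the new names.
FUSED_ATTENTION_INPUTS = {
    "x": "X",
    "ln_scale": "LnScale",
    "ln_bias": "LnBias",
    "qkv_weight": "QKVW",
    "qkv_bias": "QKVBias",
    "cache_kv": "CacheKV",
    "src_mask": "SrcMask",
    "out_linear_weight": "OutLinearW",
    "out_linear_bias": "OutLinearBias",
    "ln_scale_2": "Ln2Scale",
    "ln_bias_2": "Ln2Bias",
}
LEGACY_TO_NEW = {legacy: new for new, legacy in FUSED_ATTENTION_INPUTS.items()}

def rename_legacy_inputs(legacy_inputs):
    """Map {legacy_name: value}, e.g. from an old program desc, to the new names."""
    return {LEGACY_TO_NEW[name]: value for name, value in legacy_inputs.items()}

print(rename_legacy_inputs({"X": "var_3", "QKVW": "var_7"}))
# {'x': 'var_3', 'qkv_weight': 'var_7'}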

- op : fused_batch_norm_act
  backward : fused_batch_norm_act_grad
  inputs:
247 changes: 247 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -1604,6 +1604,253 @@ void FusedBiasActInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}

void FusedAttentionInferMeta(const MetaTensor& x,
const MetaTensor& ln_scale,
const MetaTensor& ln_bias,
const MetaTensor& qkv_weight,
const MetaTensor& qkv_bias,
const MetaTensor& cache_kv,
const MetaTensor& src_mask,
const MetaTensor& out_linear_weight,
const MetaTensor& out_linear_bias,
const MetaTensor& ln_scale_2,
const MetaTensor& ln_bias_2,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string& attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string& dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
MetaTensor* ln_mean,
MetaTensor* ln_var,
MetaTensor* ln_out,
MetaTensor* qkv_out,
MetaTensor* qkv_bias_out,
MetaTensor* transpose_out_2,
MetaTensor* qk_out,
MetaTensor* qktv_out,
MetaTensor* softmax_out,
MetaTensor* attn_dropout_mask_out,
MetaTensor* attn_dropout_out,
MetaTensor* src_mask_out,
MetaTensor* fmha_out,
MetaTensor* out_linear_out,
MetaTensor* dropout_mask_out,
MetaTensor* ln_mean_2,
MetaTensor* ln_var_2,
MetaTensor* bias_dropout_residual_out,
MetaTensor* cache_kv_out,
MetaTensor* out,
MetaConfig config) {
auto x_dim = x.dims();
auto y_dim = qkv_weight.dims();

int dim_head = 0;
int hidden_size = 0;
int nranks = 1;
if (transpose_qkv_wb) {
PADDLE_ENFORCE_EQ(y_dim.size(),
2,
phi::errors::InvalidArgument(
"The dimensions of qkv_weight must be 2 if enable "
"transpose_qkv_wb: (dim_embed, 3 * dim_embed), "
"but received dimensions of "
"Input is [%d]",
y_dim.size()));
PADDLE_ENFORCE_GT(num_heads,
0,
phi::errors::InvalidArgument(
"The num_heads must be provided and greater than 0 "
"if enable transpose_qkv_wb, but we got %d.",
num_heads));
PADDLE_ENFORCE_EQ(y_dim[0] % num_heads,
0,
phi::errors::InvalidArgument(
"First dim of qkv_w must be divisible by num heads "
"if enable transpose_qkv_wb, but receive first "
"dim of qkv_w is %d and num_heads is %d.",
y_dim[0],
num_heads));
if (ring_id == -1) {
PADDLE_ENFORCE_EQ(
y_dim[0] * 3,
y_dim[1],
phi::errors::InvalidArgument("The dimensions of qkv_weight must be 2 "
"(dim_embed, 3 * dim_embed)."));
} else {
// compute the mp nranks
nranks = (y_dim[0] * 3) / y_dim[1];
}
dim_head = y_dim[0] / (num_heads * nranks);
hidden_size = y_dim[0];
} else {
PADDLE_ENFORCE_EQ(y_dim.size(),
4,
phi::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4 if not "
"enable transpose_qkv_wb: (3, num_head, dim_head, "
"dim_embed), but received [%d]",
y_dim.size()));
PADDLE_ENFORCE_EQ(
y_dim[0],
3,
phi::errors::InvalidArgument("First dim of qkv_w must be 3 if disable "
"transpose_qkv_wb, but we got %d.",
y_dim[0]));
if (ring_id == -1) {
PADDLE_ENFORCE_EQ(
y_dim[1] * y_dim[2],
y_dim[3],
phi::errors::InvalidArgument("The dimensions of qkv_weight must be 4 "
"(3, num_head, dim_head, dim_embed), "
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
}
num_heads = y_dim[1];
dim_head = y_dim[2];
hidden_size = y_dim[3];
}

PADDLE_ENFORCE_EQ(
x_dim.size(),
3,
phi::errors::InvalidArgument("The dimensions of x must be 3 "
"(batch_size, seq_len, dim_embed), "
"but received dimensions of "
"Input is [%d]",
x_dim.size()));

PADDLE_ENFORCE_EQ(x_dim[2],
hidden_size,
phi::errors::InvalidArgument(
"ShapeError: the dimension of x_dim[2] and y_dim[3] "
"(y_dim[1] if enable transpose_qkv_w) "
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]",
x_dim,
y_dim));

if (pre_layer_norm) {
ln_mean->set_dims({x_dim[0] * x_dim[1]});
ln_var->set_dims({x_dim[0] * x_dim[1]});
ln_out->set_dims(x.dims());
} else {
ln_mean_2->set_dims({x_dim[0] * x_dim[1]});
ln_var_2->set_dims({x_dim[0] * x_dim[1]});
bias_dropout_residual_out->set_dims(x.dims());
}

if (transpose_qkv_wb) {
// [batch_size, seq_len, 3 * num_heads * dim_head]
qkv_out->set_dims({x_dim[0], x_dim[1], 3 * num_heads * dim_head});

if (qkv_bias) {
qkv_bias_out->set_dims({x_dim[0], x_dim[1], 3 * num_heads * dim_head});
}
} else {
// [batch_size, seq_len, 3, num_head, head_size]
qkv_out->set_dims({x_dim[0], x_dim[1], 3, num_heads, dim_head});

if (qkv_bias) {
qkv_bias_out->set_dims({x_dim[0], x_dim[1], 3, num_heads, dim_head});
}
}

// [3, batch_size, num_head, seq_len, head_size]
transpose_out_2->set_dims({3, x_dim[0], num_heads, x_dim[1], dim_head});

// cache_seq_len + seq_len if cache else seq_len
auto out_seq_len = x_dim[1];
if (cache_kv) {
// [2, batch_size, num_head, cache_seq_len, head_size]
auto c_dim = cache_kv.dims();

PADDLE_ENFORCE_EQ(
c_dim.size(),
5,
phi::errors::InvalidArgument("The CacheKV must be 5 dims, but got %d",
c_dim.size()));
PADDLE_ENFORCE_EQ(c_dim[0],
2,
phi::errors::InvalidArgument(
"The first dim of CacheKV must be 2, but got %d",
c_dim[0])); // 2
PADDLE_ENFORCE_EQ(c_dim[1],
x_dim[0],
phi::errors::InvalidArgument(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d",
x_dim[0],
c_dim[1])); // batch_size
PADDLE_ENFORCE_EQ(c_dim[2],
num_heads,
phi::errors::InvalidArgument(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
num_heads,
c_dim[2])); // num_head
// In the compile stage, input seq_len can be -1; in that case
// c_dim[3] may be < 0 (e.g. inside a while loop)
if (config.is_runtime) {
PADDLE_ENFORCE_GE(
c_dim[3],
0,
phi::errors::InvalidArgument(
"The fourth dim of CacheKV must be greater than 0, but got %d",
c_dim[3])); // cache_seq_len
}

PADDLE_ENFORCE_EQ(c_dim[4],
dim_head,
phi::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
dim_head,
c_dim[4])); // head_size

out_seq_len += c_dim[3];
// [3, batch_size, num_head, cache_seq_len + seq_len, head_size]
cache_kv_out->set_dims(
{c_dim[0], c_dim[1], c_dim[2], out_seq_len, c_dim[4]});
}
// [batch, num_head, seq_len, out_seq_len]
qk_out->set_dims({x_dim[0], num_heads, x_dim[1], out_seq_len});

if (src_mask) {
src_mask_out->set_dims({x_dim[0], num_heads, x_dim[1], out_seq_len});
}
// the same as QKOut's shape.
attn_dropout_out->set_dims({x_dim[0], num_heads, x_dim[1], out_seq_len});
if (is_test) {
attn_dropout_mask_out->set_dims(
{x_dim[0], num_heads, x_dim[1], out_seq_len});
}
softmax_out->set_dims({x_dim[0], num_heads, x_dim[1], out_seq_len});
// [batch_size, num_heads, seq_len, head_dim]
qktv_out->set_dims({x_dim[0], num_heads, x_dim[1], dim_head});
// [batch_size, seq_len, num_heads, head_size]
fmha_out->set_dims({x_dim[0], x_dim[1], num_heads, dim_head});

out_linear_out->set_dims(x.dims());

if (is_test == false) {
dropout_mask_out->set_dims(x.dims());
}

out->set_dims(x.dims());
}
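
Editor's note (not part of the commit): a plain-Python restatement of the main shape rules enforced by FusedAttentionInferMeta above, limited to the common path (transpose_qkv_wb == false, ring_id == -1, no cache_kv), with a small worked example. Names and the helper itself are illustrative.

# Sketch of the shape rules above, not Paddle code.
def fused_attention_shapes(x_dim, qkv_w_dim):
    # x: (batch_size, seq_len, dim_embed); qkv_weight: (3, num_head, dim_head, dim_embed)
    batch, seq_len, dim_embed = x_dim
    three, num_heads, dim_head, hidden_size = qkv_w_dim
    assert three == 3 and num_heads * dim_head == hidden_size == dim_embed
    out_seq_len = seq_len  # would grow by cache_seq_len if cache_kv were present
    return {
        "qkv_out": [batch, seq_len, 3, num_heads, dim_head],
        "transpose_out_2": [3, batch, num_heads, seq_len, dim_head],
        "qk_out": [batch, num_heads, seq_len, out_seq_len],
        "softmax_out": [batch, num_heads, seq_len, out_seq_len],
        "qktv_out": [batch, num_heads, seq_len, dim_head],
        "fmha_out": [batch, seq_len, num_heads, dim_head],
        "out": [batch, seq_len, dim_embed],
    }

# e.g. batch=2, seq_len=8, dim_embed=64 with 4 heads of size 16:
shapes = fused_attention_shapes([2, 8, 64], [3, 4, 16, 64])
print(shapes["qk_out"], shapes["out"])  # [2, 4, 8, 8] [2, 8, 64]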

void FusedLayerNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
49 changes: 49 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -334,6 +334,55 @@ void FusedBiasActInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void FusedAttentionInferMeta(const MetaTensor& x,
const MetaTensor& ln_scale,
const MetaTensor& ln_bias,
const MetaTensor& qkv_weight,
const MetaTensor& qkv_bias,
const MetaTensor& cache_kv,
const MetaTensor& src_mask,
const MetaTensor& out_linear_weight,
const MetaTensor& out_linear_bias,
const MetaTensor& ln_scale_2,
const MetaTensor& ln_bias_2,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string& attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string& dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
MetaTensor* ln_mean,
MetaTensor* ln_var,
MetaTensor* ln_out,
MetaTensor* qkv_out,
MetaTensor* qkv_bias_out,
MetaTensor* transpose_out_2,
MetaTensor* qk_out,
MetaTensor* qktv_out,
MetaTensor* softmax_out,
MetaTensor* attn_dropout_mask_out,
MetaTensor* attn_dropout_out,
MetaTensor* src_mask_out,
MetaTensor* fmha_out,
MetaTensor* out_linear_out,
MetaTensor* dropout_mask_out,
MetaTensor* ln_mean_2,
MetaTensor* ln_var_2,
MetaTensor* bias_dropout_residual_out,
MetaTensor* cache_kv_out,
MetaTensor* out,
MetaConfig config = MetaConfig());

void FusedLayerNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
1 change: 1 addition & 0 deletions test/white_list/new_ir_op_test_white_list
@@ -88,6 +88,7 @@ test_fmax_op
test_fmin_op
test_fold_op
test_frame_op
test_fused_attention_op_api
test_gather_tree_op
test_gaussian_random_op
test_generate_proposals_v2_op
