Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPT] FlashAttention && ModelParallel #51617

Merged
merged 3 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,9 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
int num_splits = 0; // 0 for an internal heuristic, which is optimal
bool zero_tensors = false;

std::vector<int64_t> seed_offset_vec;
phi::TensorToVector<int64_t>(seed_offset, ctx, &seed_offset_vec);
uint64_t seed = seed_offset_vec[0];
uint64_t offset = seed_offset_vec[1];
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);

int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
Expand Down Expand Up @@ -188,12 +187,10 @@ void FlashAttnGradKernel(const Context& ctx,

float scale = 1.0f / std::sqrt(head_size);

DenseTensor q_t_s =
Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
DenseTensor k_t_s =
Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
DenseTensor v_t_s =
Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});

DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
Expand Down
17 changes: 9 additions & 8 deletions paddle/phi/kernels/gpu/flash_attn_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,14 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
auto gen = ctx.GetGenerator();
uint64_t inc = batch_size * num_heads * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);

uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;

std::vector<int64_t> seed_offset_vec{int64_t(seed), int64_t(offset)};
phi::TensorFromVector<int64_t>(seed_offset_vec, ctx, seed_offset);
seed_offset->Resize({2});
auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);

int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;

Expand Down Expand Up @@ -210,12 +213,10 @@ void FlashAttnKernel(const Context& ctx,

float scale = 1.0f / std::sqrt(head_size);

DenseTensor q_t_s =
Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
DenseTensor k_t_s =
Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
DenseTensor v_t_s =
Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});

DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
Expand Down
47 changes: 37 additions & 10 deletions python/paddle/distributed/fleet/layers/mpu/mp_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import paddle
from paddle.autograd import PyLayer
from paddle.fluid import core
from paddle.nn import functional as F

Expand Down Expand Up @@ -328,6 +329,17 @@ def forward(self, x):
return output


class MPScale(PyLayer):
@staticmethod
def forward(ctx, x, mp_degree):
out = paddle.scale(x, 1.0 / mp_degree)
return out

@staticmethod
def backward(ctx, dout):
return dout


class RowParallelLinear(paddle.nn.Layer):
"""Linear layer with mp parallelized(row).
this class is used for splitting Linear Layer in mp group, row split the weight of the Linear layer.
Expand Down Expand Up @@ -467,6 +479,7 @@ def __init__(
from paddle.incubate.nn.functional import fused_linear

self.linear = fused_linear
self.fuse_matmul_bias = fuse_matmul_bias

def forward(self, x):
if self.input_is_parallel or (not self.is_mp):
Expand All @@ -476,16 +489,30 @@ def forward(self, x):
input_parallel = mp_ops._c_split(x, group=self.model_parallel_group)

if self.is_mp:
output_parallel = self.linear(
input_parallel, self.weight, name=self._name
)
output_ = mp_ops._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
output = output_ + self.bias if self.bias is not None else output_
if self.fuse_matmul_bias:
bias = MPScale.apply(self.bias, self.world_size)
output_parallel = self.linear(
input_parallel, self.weight, bias, name=self._name
)
output = mp_ops._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
else:
output_parallel = self.linear(
input_parallel, self.weight, name=self._name
)
output_ = mp_ops._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
output = (
output_ + self.bias if self.bias is not None else output_
)
else:
output = self.linear(
input_parallel, self.weight, self.bias, name=self._name
Expand Down
20 changes: 2 additions & 18 deletions python/paddle/distributed/fleet/layers/mpu/mp_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,7 @@ def _c_identity(tensor, group=None):
class c_identity_eager(PyLayer):
@staticmethod
def forward(ctx, tensor):
return _legacy_C_ops.c_identity(
tensor,
'use_calc_stream',
True,
'ring_id',
group.id,
'use_model_parallel',
True,
)
return tensor

@staticmethod
def backward(ctx, dy):
Expand Down Expand Up @@ -257,15 +249,7 @@ def forward(

@staticmethod
def backward(ctx, dy):
return _legacy_C_ops.c_identity(
dy,
'use_calc_stream',
True,
'ring_id',
ctx.ring_id,
'use_model_parallel',
True,
)
return dy

return mp_allreduce_eager.apply(
tensor, group, use_calc_stream, use_model_parallel
Expand Down