modify the DistributedAttention for different data pack modes
yingtongxiong committed Jan 25, 2024
1 parent 83517ca commit aa388b5
Showing 3 changed files with 43 additions and 30 deletions.
1 change: 1 addition & 0 deletions internlm/core/context/parallel_context.py
@@ -157,6 +157,7 @@ def __init__(self):
self.virtual_pipeline_parallel_size = None
self.virtual_pipeline_parallel_rank = None
self._expert_parallel_group_names = []
self.evaluation = False

@property
def config(self):
65 changes: 37 additions & 28 deletions internlm/model/multi_head_attention.py
@@ -80,48 +80,57 @@ def __init__(
self,
local_attention: Module,
sequence_process_group: dist.ProcessGroup,
first_scatter_idx: int = 2,
first_gather_idx: int = 0,
second_scatter_idx: int = 0,
second_gather_idx: int = 1,
) -> None:
super().__init__()
self.local_attn = local_attention
self.spg = sequence_process_group
self.first_scatter_idx = first_scatter_idx
self.first_gather_idx = first_gather_idx
self.second_scatter_idx = second_scatter_idx
self.second_gather_idx = second_gather_idx
self._scatter_gather_idx = {}

# _scatter_gather_idx maps each data pack mode to the scatter and gather
# indices used in the all-to-all. The key is the pack mode, one of
# ['qkv', 'kv', 'q', 'output']; the value is the [scatter_idx, gather_idx] pair.
self._scatter_gather_idx['qkv'] = [2, 0]  # qkv shape: [sequence, 3, head, head_dim]
self._scatter_gather_idx['kv'] = [2, 0]  # kv shape: [sequence, 2, head, head_dim]
self._scatter_gather_idx['q'] = [1, 0]  # q/k/v shape: [sequence, head, head_dim]
self._scatter_gather_idx['output'] = [0, 1]  # output shape: [sequence, head, head_dim]


def forward(self, qkv: Tensor = None, kv: Tensor = None, q: Tensor = None, k: Tensor = None, v: Tensor = None, **kwargs: Any) -> Tensor:
    if gpc.evaluation:
        # When conducting evaluation, the inputs carry a leading batch dimension,
        # so every scatter and gather index is shifted by 1.
        eval_scatter_gather_idx = {key: [x + 1 for x in value] for key, value in self._scatter_gather_idx.items()}
        return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=eval_scatter_gather_idx, **kwargs)
    return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=self._scatter_gather_idx, **kwargs)

def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
def _forward(self, qkv: Tensor = None, kv: Tensor = None, q: Tensor = None, k: Tensor = None, v: Tensor = None, scatter_gather: dict = None, **kwargs: Any) -> Tensor:
"""forward
Arguments:
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
qkv (Tensor): packed qkv input to the layer
kv (Tensor): packed kv input to the layer
q (Tensor): q input to the layer
k (Tensor): k input to the layer
v (Tensor): v input to the layer
kwargs (Any): other keyword arguments
Returns:
* output (Tensor): context output
"""
# Evaluation
if qkv.ndim == 5:
# in shape: [batch, seq/tp_size, 3, head, head_dim]
qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1)
# out shape : [batch, seq, head/tp_size, head_dim]
context_layer = self.local_attn(qkv, **kwargs)
# in shape: [batch, seq, head/tp_size, head_dim]
output = _SeqAllToAll.apply(
self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1
)
else: # training
# in shape: [seq/tp_size, 3, head, head_dim]
qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx)
# out shape : [seq, head/tp_size, head_dim]

if qkv is not None:
    # packed qkv mode
    qkv = _SeqAllToAll.apply(self.spg, qkv, scatter_gather['qkv'][0], scatter_gather['qkv'][1])
    context_layer = self.local_attn(qkv, **kwargs)
elif kv is not None:
    # packed kv mode, with q passed separately
    q = _SeqAllToAll.apply(self.spg, q, scatter_gather['q'][0], scatter_gather['q'][1])
    kv = _SeqAllToAll.apply(self.spg, kv, scatter_gather['kv'][0], scatter_gather['kv'][1])
    context_layer = self.local_attn(q, kv, **kwargs)
else:
    # q, k and v passed separately; all three use the 'q' indices
    q = _SeqAllToAll.apply(self.spg, q, scatter_gather['q'][0], scatter_gather['q'][1])
    k = _SeqAllToAll.apply(self.spg, k, scatter_gather['q'][0], scatter_gather['q'][1])
    v = _SeqAllToAll.apply(self.spg, v, scatter_gather['q'][0], scatter_gather['q'][1])
    context_layer = self.local_attn(q, k, v, **kwargs)

output = _SeqAllToAll.apply(self.spg, context_layer, scatter_gather['output'][0], scatter_gather['output'][1])

# out e.g., [s/p::h]
return output
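
As a reading aid, here is a minimal sketch of the redistribution that a [scatter_idx, gather_idx] pair drives in a sequence-parallel all-to-all; the helper below is hypothetical, not the commit's _SeqAllToAll, and it assumes an initialized torch.distributed process group:

import torch
import torch.distributed as dist

def seq_all_to_all(x: torch.Tensor, group: dist.ProcessGroup, scatter_idx: int, gather_idx: int) -> torch.Tensor:
    # Split x along scatter_idx, exchange one shard with every rank,
    # then reassemble the received shards along gather_idx.
    world_size = dist.get_world_size(group)
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_idx)]
    outputs = [torch.empty_like(t) for t in inputs]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_idx)

With the 'qkv' entry [2, 0], a packed [seq/p, 3, head, head_dim] tensor is scattered over heads and gathered over the sequence, giving each rank [seq, 3, head/p, head_dim] for its local attention; the 'output' entry [0, 1] then reverses the layout.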
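The evaluation path in forward reuses the same table with every index shifted by one to account for the leading batch dimension; a quick check of that dict comprehension, using the values from the table above:

train_idx = {'qkv': [2, 0], 'kv': [2, 0], 'q': [1, 0], 'output': [0, 1]}
eval_idx = {key: [x + 1 for x in value] for key, value in train_idx.items()}
assert eval_idx == {'qkv': [3, 1], 'kv': [3, 1], 'q': [2, 1], 'output': [1, 2]}
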
7 changes: 5 additions & 2 deletions internlm/utils/evaluation.py
@@ -47,18 +47,21 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape


@contextmanager
def switch_sequence_parallel_mode():
def switch_evaluation_mode():
prev_mode = gpc.config.parallel.sequence_parallel
prev_evaluation = gpc.evaluation
try:
# During training x.shape is torch.Size([1024, 4096]) and the linear layer
# all-gathers along dim=0 (the sequence dim); during evaluation x.shape is
# torch.Size([1, 1024, 4096]), so gathering along dim=0 would be wrong.
if gpc.config.parallel["tensor"]["mode"] == "isp":
gpc.config.parallel.sequence_parallel = True
else:
gpc.config.parallel.sequence_parallel = False
gpc.evaluation = True
yield
finally:
gpc.config.parallel.sequence_parallel = prev_mode
gpc.evaluation = prev_evaluation


def evaluate_on_val_dls(
Expand All @@ -70,7 +73,7 @@ def evaluate_on_val_dls(
update_panel: bool = False,
streaming: bool = False,
):
with switch_sequence_parallel_mode():
with switch_evaluation_mode():
torch.cuda.empty_cache()
trainer.eval()
verbose = gpc.is_rank_for_log()
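
The comment inside switch_evaluation_mode is the crux of this file's change. A rough, self-contained illustration of it (shapes match the comment; the tp factor and the concat stand-in for an all-gather are hypothetical):

import torch

train_x = torch.zeros(1024, 4096)    # training: [seq/tp, hidden], no batch dim
eval_x = torch.zeros(1, 1024, 4096)  # evaluation: [batch, seq, hidden]
tp = 4                               # hypothetical tensor-parallel world size

# Model an all-gather along dim 0 as concatenating the tp shards.
gathered_train = torch.cat([train_x] * tp, dim=0)  # [4096, 4096]: the full sequence, as intended
gathered_eval = torch.cat([eval_x] * tp, dim=0)    # [4, 1024, 4096]: a bogus batch dim, not a longer sequence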