diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index 6758c7a657f6..57be0c793856 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -24,7 +24,9 @@ def set_n_embd(num): def get_num_kv_heads(): global num_kv_heads - return num_kv_heads + if 'num_kv_heads' in globals(): + return num_kv_heads + return None def get_num_attention_heads(): diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index e809fe1118b5..8fd3ca918433 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -10,6 +10,8 @@ import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator +from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads +from deepspeed.utils import groups def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim): @@ -38,8 +40,132 @@ def post_func(input): return post_func +def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group): + seq_world_size = dist.get_world_size(group) + inp_shape = list(input.shape) + assert batch_dim_idx in [0, 1], "batch_dim_idx must be either 0 or 1" + + if not (scatter_idx < 2): + input_splits = get_shard_size_list(inp_shape[scatter_idx], seq_world_size) + input = input.transpose(0, scatter_idx).contiguous() + local_heads = input_splits[groups._get_sequence_parallel_rank()] + output_splits = [local_heads] * seq_world_size + + output_buffer_shape = [seq_world_size * local_heads] + list(input.shape[1:]) + output = torch.empty(output_buffer_shape, device=input.device, dtype=input.dtype) + dist.all_to_all_single(output,input,output_split_sizes=output_splits,\ + input_split_sizes=input_splits,group=group) + ###[seq_ws*local_heads, ...] to [seq_ws, local_heads, ...] + output = output.view(seq_world_size, local_heads, *output.shape[1:]) + ###[seq_ws,local_heads,b,seq_len,...] to [seq_ws,seq_len,b,local_heads,...] + + ### batch_dim_idx=0 [seq_ws,local_heads,seq_len,b,...] to [b, seq_ws, seq_len, local_heads ...] + ### batch_dim_idx=1 [seq_ws,local_heads,b,seq_len,...] to [seq_ws,seq_len,b,local_heads,...] + if batch_dim_idx == 0: + order = [3, 0, 2, 1] + list(range(4, len(output.shape))) + output = output.permute(order).contiguous() + ###[b, seq_ws*local_seq_len, local_heads,...] + output = output.view(output.shape[0], inp_shape[gather_idx] * seq_world_size, + *output.shape[3:]).contiguous() + elif batch_dim_idx == 1: + output = output.transpose(1, 3).contiguous() + ###[seq_ws*local_seq_len, b, local_heads,...] + output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous() + else: + # The compatibility handling of 4D and 3D tensors, standardizing to 3D. + input = input.reshape(input.shape[0], input.shape[1], -1) + + if batch_dim_idx == 0: #b,s,h + input = input.permute(1, 2, 0).contiguous() #s,h,b + elif batch_dim_idx == 1: #s,b,h + input = input.transpose(1, 2).contiguous() #s,h,b + seq_len, h, batch_size = input.shape + num_local_heads_list = get_shard_size_list(get_num_kv_heads(), seq_world_size) + local_heads = num_local_heads_list[groups._get_sequence_parallel_rank()] + h_dim = h // local_heads + local_seq_len = seq_len // seq_world_size + + input = input.view(seq_len * h, batch_size) + local_seq_len_with_heads = int(input.shape[0] / seq_world_size) # dim size of local_seq_len*local_heads*hdim + input_splits = [local_seq_len_with_heads] * seq_world_size + coeff = local_seq_len_with_heads // local_heads #per head: dim size of local_seq_len*hdim + + #uneven seq_world_size coeff, total_heads/local_heads. + heads_scale_coeff = get_num_kv_heads() / local_heads + + output_splits = [num_local_heads * coeff for num_local_heads in num_local_heads_list] + output_buff_d1_size = int(heads_scale_coeff * local_seq_len_with_heads) + total_h = int(inp_shape[gather_idx] * heads_scale_coeff) + output = torch.empty(output_buff_d1_size, input.shape[1], device=input.device, dtype=input.dtype) + dist.all_to_all_single(output,input,output_split_sizes=output_splits, \ + input_split_sizes=input_splits,group=group) + ################## + #suppose 7 heads divide into 4 ranks [2,2,2,1] + #chunk_num_heads_small=floor(7/4)=1 + #chunk_num_heads_large=ceil(7/4)=2 + #num_chunk_heads_large=len([2,2,2])=3, all2all_buffer_counts + #num_chunk_heads_small=len([1])=1, all2all_buffer_counts + #total_num_large_heads=sum([2,2,2])=7 + #total_num_small_heads=sum([1])=1 + + chunk_num_heads_small = get_num_kv_heads() // seq_world_size # even heads compatible + chunk_num_heads_large = chunk_num_heads_small + 1 + num_chunk_heads_large = get_num_kv_heads() % seq_world_size + num_chunk_heads_small = seq_world_size - num_chunk_heads_large + total_num_large_heads = num_chunk_heads_large * chunk_num_heads_large + total_num_small_heads = num_chunk_heads_small * chunk_num_heads_small + + heads_large_combine_size = coeff * total_num_large_heads + heads_small_combine_size = coeff * total_num_small_heads + heads_large_chunk, heads_small_chunk = output.split([heads_large_combine_size, heads_small_combine_size], + dim=0) + heads_large_chunk = heads_large_chunk.view(num_chunk_heads_large, local_seq_len, chunk_num_heads_large, h_dim, + batch_size) + heads_small_chunk = heads_small_chunk.view(num_chunk_heads_small, local_seq_len, chunk_num_heads_small, h_dim, + batch_size) + if batch_dim_idx == 0: + #[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[batch,local_seq_len,all2all_buffer_counts*n_heads,dim] + order = [4, 1, 0, 2, 3] + heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(batch_size, local_seq_len, + total_num_large_heads, h_dim) + heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(batch_size, local_seq_len, + total_num_small_heads, h_dim) + elif batch_dim_idx == 1: + #[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[local_seq_len,batch,all2all_buffer_counts*n_heads,dim] + order = [1, 4, 0, 2, 3] + heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(local_seq_len, batch_size, + total_num_large_heads, h_dim) + heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(local_seq_len, batch_size, + total_num_small_heads, h_dim) + + output = torch.cat([heads_large_chunk, heads_small_chunk], dim=2).contiguous() + + inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size + output_shape= inp_shape[: gather_idx] + \ + [total_h,] + \ + inp_shape[gather_idx + 1:] + + output = output.view(output_shape) + + return output + + def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None): seq_world_size = dist.get_world_size(group) + # we only need num_heads once + num_heads = input.shape[2] + + if get_num_kv_heads() is not None or num_heads % seq_world_size != 0: + # Assuming here that the number of heads for q is consistent with kv + # If not, additional logic is required for cases like GQA + if get_num_kv_heads() is None: + assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})" + # set heads at first call by num_total_heads. + # then use ``get_num_kv_heads() is not None`` to re-entry uneven path. + set_num_kv_heads(num_heads) + assert async_op == False, "uneven head sp does not support async op" + return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group) + if batch_dim_idx == 0: # b, s, n, h if scatter_idx < 2: diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py index 9dd288ef46db..e9550a0ec25a 100755 --- a/deepspeed/utils/groups.py +++ b/deepspeed/utils/groups.py @@ -484,6 +484,8 @@ def _get_sequence_parallel_rank(): global mpu if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'): return mpu.get_sequence_parallel_rank() + if mesh_device is not None: + return dist.get_rank(mesh_device.get_group(mesh_dim="sequence_parallel")) return 0 diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 915c89e0b00a..d9ed54322d5c 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -11,9 +11,12 @@ from unit.common import DistributedTest from deepspeed.sequence.layer import _SeqAllToAll from unit.util import skip_on_arch +from unit.simple_model import * +from deepspeed.utils import groups +from deepspeed.module_inject.tp_shard import get_shard_size_list +#Use mesh device to create data and sequence parallel group -#Use mesh device to create data and sequence parallel group class TestUlyssesUtils(DistributedTest): world_size = 4 @@ -75,3 +78,82 @@ def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_ # Check outputs are the same as input for i in range(1, len(outputs)): assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}" + + +@pytest.mark.parametrize("d0", [2, 4]) #batch or sequence dimension +@pytest.mark.parametrize("d1", [4, 8]) #batch or sequence dimension +@pytest.mark.parametrize("num_heads", [3, 7]) +@pytest.mark.parametrize("head_dim", [16]) +class TestUlyssesAll2All_odd(DistributedTest): + world_size = 4 + + def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None: + + data_parallel_size = 2 + seq_parallel_size = self.world_size // data_parallel_size + skip_on_arch(min_arch=8) + + def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): + d0 += offset_d0 + d1 += offset_d1 + h += offset_h + return d0 * 10 + h + d1 * 0.1 + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + ds_engine, _, _, _ = initialize(model=model, + config_params={"train_batch_size": 8}, + mesh_param=(data_parallel_size, seq_parallel_size)) + + scatter_idx = 2 + outputs = [] + inputs = [] + batch_dims = [0, 1] + seq_dims = [1, 0] + + for idx, seq_dim in enumerate(seq_dims): + gather_idx = seq_dim + batch_dim_idx = batch_dims[idx] + + #4D tensor : b,s,h,d or s,b,h,d + #create a hash tensor from pos_id, head_id, and batch_id + d0_indices = torch.arange(d0).reshape(-1, 1, 1, 1) + d1_indices = torch.arange(d1).reshape(1, -1, 1, 1) + h_indices = torch.arange(num_heads).reshape(1, 1, -1, 1) + input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device) + if batch_dim_idx == 1: #seq_len_dim : 0(d0) + input_tensor[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, + d0 * groups._get_sequence_parallel_rank(), 0) + elif batch_dim_idx == 0: #seq_len_dim : 1(d1) + input_tensor[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, 0, + d1 * groups._get_sequence_parallel_rank()) + inputs.append(input_tensor) + + ### first all2all: sequence parallel to head parallel + s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx, + batch_dim_idx) + + # s2h_tensor check for the first all2all: compare with the expected ground truth + d0_indices = torch.arange(s2h_tensor.shape[0]).reshape(-1, 1, 1, 1) + d1_indices = torch.arange(s2h_tensor.shape[1]).reshape(1, -1, 1, 1) + h_indices = torch.arange(s2h_tensor.shape[2]).reshape(1, 1, -1, 1) + shard_list = get_shard_size_list(num_heads, groups._get_sequence_parallel_world_size()) + head_offset = sum(shard_list[:groups._get_sequence_parallel_rank()]) + s2h_truth = torch.zeros_like(s2h_tensor) + s2h_truth[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, 0, 0, head_offset) + + assert torch.allclose(s2h_truth, + s2h_tensor), f"s2h_tensor differs from the expected for sequence dim: {seq_dim}" + #No op + ### second all2all: head parallel to sequence parallel + h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx, + batch_dim_idx) + print( + f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}' + ) + outputs.append(h2s_tensor) + + # Check outputs for the second all2all + for i in range(0, len(outputs)): + assert torch.allclose(inputs[i], + outputs[i]), f"[{dist.get_rank()}]Outputs differ for sequence dim {seq_dims[i]}"