From c4b1b65931a7614bfbed180c1e18c39b7cb61c2d Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 26 Jun 2023 15:50:07 +0800 Subject: [PATCH] [test] fixed tests failed due to dtensor change (#4082) * [test] fixed tests failed due to dtensor change * polish code --- .../tensor_shard/node_handler/node_handler.py | 2 +- .../strategy/matmul_strategy_generator.py | 6 +- .../auto_parallel/tensor_shard/utils/misc.py | 4 +- colossalai/checkpoint_io/utils.py | 4 +- colossalai/lazy/lazy_init.py | 11 ++- colossalai/tensor/comm_spec.py | 97 +++++++++---------- colossalai/tensor/d_tensor/comm_spec.py | 88 ++++++++--------- colossalai/tensor/d_tensor/layout.py | 13 +-- .../tensor/d_tensor/layout_converter.py | 71 +++++++------- colossalai/tensor/shape_consistency.py | 6 +- colossalai/tensor/sharding_spec.py | 6 +- test.py | 1 - .../test_autochunk_unet.py | 11 +-- .../test_gemini_checkpoint_io.py | 4 +- tests/test_device/test_device_mesh.py | 10 +- tests/test_device/test_init_logical_pg.py | 16 ++- .../test_hf_model/hf_tracer_utils.py | 14 ++- .../test_hf_model/test_hf_albert.py | 2 +- .../test_tracer/test_hf_model/test_hf_bert.py | 4 +- .../test_hf_model/test_hf_diffuser.py | 2 +- .../test_tracer/test_hf_model/test_hf_gpt.py | 4 +- .../test_tracer/test_hf_model/test_hf_opt.py | 2 +- .../test_tracer/test_hf_model/test_hf_t5.py | 9 +- .../test_timm_model/test_timm_model.py | 2 +- .../test_torchaudio_model.py | 2 +- .../test_torchrec_model/test_deepfm_model.py | 2 +- .../test_torchrec_model/test_dlrm_model.py | 2 +- .../test_torchvision_model.py | 2 +- tests/test_lazy/lazy_init_utils.py | 4 +- tests/test_lazy/test_distribute.py | 30 +++--- tests/test_lazy/test_models.py | 2 +- .../test_dtensor/test_comm_spec.py | 33 ++----- .../test_tensor/test_dtensor/test_dtensor.py | 2 +- .../test_dtensor/test_layout_converter.py | 43 +++----- tests/test_tensor/test_shape_consistency.py | 7 +- tests/test_tensor/test_sharded_linear.py | 2 +- tests/test_tensor/test_sharding_spec.py | 2 +- 37 files changed, 233 insertions(+), 289 deletions(-) delete mode 100644 test.py diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 4262d76173e4..b4b7b0e794d1 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -188,7 +188,7 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV remove_strategy_list = [] for strategy in self.strategies_vector: shard_axis_list = [] - last_axis = len(self.device_mesh.mesh_shape) - 1 + last_axis = len(self.device_mesh.shape) - 1 for op_data, sharding_spec in strategy.sharding_specs.items(): if op_data.data is not None and isinstance(op_data.data, torch.Tensor): for dim, shard_axes in sharding_spec.dim_partition_dict.items(): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index 1ce5a08f2d6b..aa1581b99e0f 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -984,7 +984,7 @@ def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] device_mesh_is_1d = True - if len(self.device_mesh.mesh_shape) == 2 and 1 not 
in self.device_mesh.mesh_shape: + if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape: device_mesh_is_1d = False if device_mesh_is_1d: @@ -992,10 +992,10 @@ def collate_strategies(self) -> List[ShardingStrategy]: # Sb = Sb x Sb # can be None as it is only for 1D device mesh # only for 1D device mesh - if len(self.device_mesh.mesh_shape) == 1: + if len(self.device_mesh.shape) == 1: mesh_dim = 0 else: - mesh_dim = self.device_mesh.mesh_shape.index(1) + mesh_dim = self.device_mesh.shape.index(1) strategy_list.append(self.split_one_batch_dim(mesh_dim)) else: # for 2D device mesh diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index 9e402dab7578..475e95fc4326 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -46,8 +46,8 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens # make sure all dims are covered in sharding spec sharding_len = len(sharding_spec.sharding_sequence) tensor_num_dim = tensor.dim() - num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0] - num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1] + num_devices_in_col = sharding_spec.device_mesh.shape[0] + num_devices_in_row = sharding_spec.device_mesh.shape[1] assert sharding_len == tensor_num_dim, \ f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 68981dff0d0a..485577b9650c 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) for key, weight in state_dict.items(): ret_block = None ret_block_size = 0 - if is_distributed_tensor(weight): + if not is_distributed_tensor(weight): weight_size = calculate_tensor_size(weight) # If this weight is going to tip up over the maximal size, we split. @@ -146,7 +146,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> continue # If the states are stored as DTensors, mark isDTensor as true. - if type(state_tensor) == DTensor: + if is_distributed_tensor(state_tensor): isDTensor = True state_size += calculate_tensor_size(state_tensor) diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 1e45eced5f34..8b911407307c 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,5 +1,5 @@ from types import MethodType -from typing import Callable, Optional, Union +from typing import Callable, Dict, Optional, Union import torch import torch.distributed as dist @@ -173,7 +173,7 @@ def materialize(self) -> torch.Tensor: self.clean() return _convert_cls(self, target) - def distribute(self, layout: Layout) -> torch.Tensor: + def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. 
Args: @@ -537,7 +537,10 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) @staticmethod - def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module: + def distribute(module: nn.Module, + device_mesh: DeviceMesh, + sharding_spec_dict: Dict[str, ShardingSpec], + verbose: bool = False) -> nn.Module: """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: @@ -547,7 +550,7 @@ def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> n """ def apply_fn(name: str, p: LazyTensor): - p.distribute(layout_dict[name]) + p.distribute(device_mesh, sharding_spec_dict[name]) return _apply_to_lazy_module(module, apply_fn, verbose) diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index af38d2a502c2..204f81343199 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -16,69 +16,66 @@ def _all_gather(tensor, comm_spec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) - for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) - ] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) + for _ in range(comm_spec.device_mesh.shape[comm_spec.logical_process_axis]) + ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor, comm_spec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor, comm_spec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. 
''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor, comm_spec, async_op=False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. 
''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor def _mix_gather(tensor, comm_spec): @@ -128,7 +125,7 @@ def _mix_gather(tensor, comm_spec): process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] ''' - total_slices = comm_spec.device_mesh.mesh_shape[0] + total_slices = comm_spec.device_mesh.shape[0] tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)] leading_group_dim = comm_spec.logical_process_axes[0] assert len(comm_spec.device_mesh.process_groups_dict) == 1 @@ -149,7 +146,7 @@ def _mix_gather(tensor, comm_spec): if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous() else: - mesh_shape = comm_spec.device_meshes.mesh_shape + mesh_shape = comm_spec.device_meshes.shape cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] tmp_tensor_shape = list(tensor.shape) tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0] @@ -181,9 +178,9 @@ def _mix_split(tensor, comm_spec): # [4, 5, 6, 7]] # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} ''' - mesh_shape = comm_spec.device_meshes.mesh_shape + mesh_shape = comm_spec.device_meshes.shape dim = comm_spec.gather_dim - total_slices = comm_spec.device_mesh.mesh_shape[0] + total_slices = comm_spec.device_mesh.shape[0] # Get global rank rank = dist.get_rank() @@ -414,7 +411,7 @@ def __init__(self, self.forward_only = forward_only if isinstance(self.logical_process_axis, list): if not mix_gather: - self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.device_mesh = self.sharding_spec.device_mesh.flatten() self.logical_process_axis = 0 else: self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 159125fa16db..79b2e3ef936a 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -24,12 +24,12 @@ class CommSpec: ''' Communication spec is used to record the communication action. It converts the communication spec to real action which will be used in runtime. It contains comm_pattern to determine the - communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim + communication method, process_group_dict to determine the process groups, gather_dim and shard_dim to determine the buffer shape, and logical_process_axis Argument: - comm_pattern(CollectiveCommPattern): describe the communication method used in this spec. - process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec. + comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. 
+ process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec. gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. @@ -37,7 +37,7 @@ class CommSpec: def __init__(self, comm_pattern: CollectiveCommPattern, - process_groups_dict: Dict, + process_group_dict: Dict, gather_dim: int = None, shard_dim: int = None, logical_process_axis: int = None): @@ -45,7 +45,7 @@ def __init__(self, self.gather_dim = gather_dim self.shard_dim = shard_dim self.logical_process_axis = logical_process_axis - self.process_groups_dict = process_groups_dict + self.process_group_dict = process_group_dict def __repr__(self): res_list = ["CommSpec:("] @@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. 
''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor class _ReduceGrad(torch.autograd.Function): @@ -269,7 +257,7 @@ def symbolic(graph, input_): def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - process_groups_dict=comm_spec.process_groups_dict, + process_group_dict=comm_spec.process_group_dict, gather_dim=comm_spec.shard_dim, shard_dim=comm_spec.gather_dim, logical_process_axis=comm_spec.logical_process_axis) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index 4185b85860e3..a35b2f43e44b 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -14,24 +14,21 @@ class Layout: Attributes: device_mesh: the device mesh to store the tensor distributed. - device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'. sharding_spec: the sharding specification to describe how the tensor is sharded. - entire_shape: the entire shape of the global tensor. + global_shape: the entire shape of the global tensor. 
""" - def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec, - entire_shape: torch.Size): + def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size): self.device_mesh = device_mesh - self.device_type = device_type self.sharding_spec = sharding_spec - self.entire_shape = entire_shape + self.global_shape = global_shape self._sanity_check() def __hash__(self) -> int: return hash(f'{self.sharding_spec}') def get_sharded_shape_per_device(self): - sharded_shape = list(self.entire_shape) + sharded_shape = list(self.global_shape) for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) @@ -56,7 +53,7 @@ def _sanity_check(self): # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): - tensor_dim_size = self.entire_shape[dim] + tensor_dim_size = self.global_shape[dim] num_devices = 1 for element in shard_list: diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index 14f9c4561622..528ed7901c4f 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -3,10 +3,8 @@ from dataclasses import dataclass from typing import Dict, List, Tuple -import numpy as np import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.layout import Layout @@ -37,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions): class LayoutConverter(metaclass=SingletonMeta): + """ + LayoutConverter is a singleton class which converts the layout of a distributed tensor. 
+ """ def __init__(self): self._options = None @@ -79,15 +80,14 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) for layout, comm_spec in rst_dict.items(): @@ -100,7 +100,12 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + for target_pair in source_spec.dim_partition_dict.items(): shard_list = all_gather_simulator(target_pair) index = target_pair[0] @@ -118,7 +123,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co logical_process_axis = target_pair[1][-1] comm_spec = CommSpec( comm_pattern, - process_groups_dict=process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, # shard_dim will be used during backward shard_dim=gather_dim, @@ -129,8 +134,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: @@ -155,15 +159,14 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_to_all_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -176,7 +179,12 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + source_spec = source_layout.sharding_spec tensor_dims = source_spec.dims for f_index in range(tensor_dims - 1): @@ -217,7 +225,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com 
shard_dim = f_index logical_process_axis = b_target_pair[1][-1] comm_spec = CommSpec(comm_pattern, - process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -240,8 +248,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -266,16 +273,15 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0]} # [S0,R,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.shard_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -289,7 +295,11 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] # legal sharding dims means the mesh_id is still available to use. 
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))] @@ -317,7 +327,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec shard_dim = index logical_process_axis = shard_list[-1] comm_spec = CommSpec(comm_pattern, - process_groups_dict, + process_group_dict=process_group_dict, gather_dim=shard_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -328,8 +338,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -387,7 +396,7 @@ def layout_converting(self, source_layout: Layout, # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -395,16 +404,14 @@ def layout_converting(self, source_layout: Layout, # [R,S01,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [S01,R,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) @@ -493,21 +500,19 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) # [S0,R,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [R,S0,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) if rank in (0, 1): sharded_tensor_0 = torch.zeros(2, 1) diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index 5bec552d69d5..99d782c3f6e8 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -285,7 +285,7 @@ def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD # legal sharding dims means the mesh_id is still available to use. 
- legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))] + legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: legal_sharding_dims.remove(element) @@ -435,7 +435,7 @@ def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, """ input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) - output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis] peak_numel = max(peak_numel, alloc_numel + output_numel * 2) alloc_numel += output_numel if discard_input: @@ -461,7 +461,7 @@ def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, p # generate a new tensor input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) - output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis] alloc_numel += output_numel peak_numel = max(peak_numel, alloc_numel) if discard_input: diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index 406ad49097b5..e594fd297dc4 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -195,7 +195,7 @@ def __init__(self, def __repr__(self): res_list = ["DistSpec:"] res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) - res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}") + res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}") return ' '.join(res_list) def _sanity_check(self): @@ -222,7 +222,7 @@ def _sanity_check(self): num_devices = 1 for element in shard_list: - num_devices *= self.device_mesh.mesh_shape[element] + num_devices *= self.device_mesh.shape[element] if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( @@ -288,7 +288,7 @@ def get_sharded_shape_per_device(self): sharded_shape = list(self.entire_shape) for dim, shard_list in self.dim_partition_dict.items(): - mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) assert sharded_shape[ dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' 
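
A minimal usage sketch of the renamed d_tensor API that the library-side hunks above migrate to: DeviceMesh is built from a flat rank tensor and exposes .shape plus per-axis process groups, Layout drops device_type and renames entire_shape to global_shape, and CommSpec/LayoutConverter work with a process group per mesh axis. The snippet is distilled from the docstring examples in those hunks; the 4-process distributed launch it relies on (e.g. via colossalai.launch) is assumed and not shown here.

    import torch

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.d_tensor.layout import Layout
    from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
    from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

    # physical mesh ids are now passed as a flat tensor; the old .reshape(...) style is gone
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # [S0, S1, R]: tensor dim 0 is sharded over mesh axis 0, dim 1 over mesh axis 1
    sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})

    # Layout no longer takes device_type, and entire_shape is renamed to global_shape
    global_shape = (4, 4, 4)
    layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)

    # per-axis process groups are queried straight from the device mesh,
    # replacing iteration over the old process_groups_dict
    process_group_dict = device_mesh.get_process_group_for_all_axes()

    # enumerate the layouts reachable from `layout` with a single all-gather
    layout_converter = LayoutConverter()
    rst_dict = layout_converter.all_gather_transform_layouts(layout)
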
diff --git a/test.py b/test.py deleted file mode 100644 index f283e21a1ebd..000000000000 --- a/test.py +++ /dev/null @@ -1 +0,0 @@ -from colossalai.tensor.d_tensor.api import to_distributed_tensor diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py index fc9d8455ed5c..f0cf2a5fcbca 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -58,13 +58,4 @@ def test_evoformer_block(model, shape, max_memory): if __name__ == "__main__": - run_test( - rank=0, - data=get_data(LATENTS_SHAPE), - max_memory=None, - model=UNet2DModel, - print_code=False, - print_mem=True, - print_est_mem=False, - print_progress=False, - ) + test_evoformer_block() diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 14d69cab2176..602cf468c944 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -22,7 +22,7 @@ @parameterize('use_safetensors', [False, True]) def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool): from transformers import BertForSequenceClassification - (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() with shared_tempdir() as tempdir: @@ -53,7 +53,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b @parameterize('shard', [True, False]) @parameterize('model_name', ['transformers_gpt']) def exam_state_dict(placement_policy, shard: bool, model_name: str): - (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() plugin = GeminiPlugin(placement_policy=placement_policy) booster = Booster(plugin=plugin) diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index e9f0f9477e4a..590d6966bff6 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -8,18 +8,16 @@ def test_device_mesh(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], # [8, 9, 10,11], # [12,13,14,15]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - assert device_mesh.convert_map[5] == [1, 1] - assert device_mesh.convert_map[11] == [2, 3] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]] - assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] + assert device_mesh.global_rank_to_local_rank(5) == [1, 1] + assert device_mesh.global_rank_to_local_rank(11) == [2, 3] + assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3] def check_1d_device_mesh(): diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 2b7060c4846a..7c6339eff67e 100644 --- a/tests/test_device/test_init_logical_pg.py +++ 
b/tests/test_device/test_init_logical_pg.py @@ -20,16 +20,12 @@ def check_layer(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]} - logical_process_groups = device_mesh.process_groups_dict - - for mesh_dim, pgs in logical_pg_dict.items(): - for index, pg in enumerate(pgs): - if rank in pg: - tensor = torch.ones(4).cuda() - group = logical_process_groups[mesh_dim][index][1] - dist.all_reduce(tensor, op=ReduceOp.SUM, group=group) - assert tensor.equal(tensor_to_check) + + for axis in range(len(mesh_shape)): + tensor = torch.ones(4).cuda() + pg = device_mesh.get_process_group(axis=axis) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg) + assert tensor.equal(tensor_to_check) gpc.destroy() diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index 7a4bf131ae36..58c8132e1490 100644 --- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -1,3 +1,5 @@ +from typing import List + import torch from numpy import isin from torch.fx import GraphModule @@ -7,19 +9,23 @@ from colossalai._analyzer.fx import symbolic_trace -def trace_model_and_compare_output(model, data_gen): +def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = None): # must turn on eval mode to ensure the output is consistent model.eval() + inputs = data_gen() + + if ignore_data is not None: + # drop the ignore_data key + inputs = {k: v for k, v in inputs.items() if k not in ignore_data} + try: - kwargs = data_gen() - meta_args = {k: v.to('meta') for k, v in kwargs.items()} + meta_args = {k: v.to('meta') for k, v in inputs.items()} gm = symbolic_trace(model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") # run forward - inputs = data_gen() non_fx_out = model(**inputs) fx_out = gm(**inputs) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index f4d681221191..a1470400ad82 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -15,7 +15,7 @@ def test_albert(): sub_registry = model_zoo.get_sub_registry('transformers_albert') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() trace_model_and_compare_output(model, data_gen_fn) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index a833bb30c056..632ad366ccc4 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -12,9 +12,9 @@ def test_bert(): sub_registry = model_zoo.get_sub_registry('transformers_bert') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label']) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index 
ccbe2da58bf2..ac87a7fcb13b 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -47,7 +47,7 @@ def test_diffusers(): sub_model_zoo = model_zoo.get_sub_registry('diffusers') - for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() trace_and_compare(model_fn, data, output_transform_fn) torch.cuda.synchronize() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 67107469d8bb..31bcb7028e25 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -12,7 +12,7 @@ def test_gpt(): sub_registry = model_zoo.get_sub_registry('transformers_gpt') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() # TODO: support the following models @@ -21,7 +21,7 @@ def test_gpt(): if model.__class__.__name__ in ['GPT2DoubleHeadsModel']: continue - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index 369545b03de1..f528db6a64ef 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -12,7 +12,7 @@ def test_opt(): sub_registry = model_zoo.get_sub_registry('transformers_opt') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() trace_model_and_compare_output(model, data_gen_fn) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 811cf3b21430..45e06bc2bbb0 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -12,9 +12,14 @@ def test_t5(): sub_registry = model_zoo.get_sub_registry('transformers_t5') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): + if name == "transformers_t5_for_conditional_generation": + # cannot trace for loss function yet + # so we use a data gen which does not produce labels + data_gen_fn = sub_registry.get('transformers_t5')[1] + model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 117c70c84aa8..98433b8f7c3b 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -56,7 +56,7 @@ def test_timm_models(): sub_model_zoo = model_zoo.get_sub_registry('timm') - for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() if attribute is not None and 
attribute.has_control_flow: meta_args = {k: v.to('meta') for k, v in data.items()} diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py index f73c5bb9a590..2b7def5bef85 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -16,7 +16,7 @@ def test_torchaudio_models(): sub_model_zoo = model_zoo.get_sub_registry('torchaudio') - for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): model = model_fn() trace_and_compare(model, data_gen_fn, diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index df02568c0049..f969c8e6c3da 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -53,7 +53,7 @@ def test_torchrec_deepfm_models(): deepfm_models = model_zoo.get_sub_registry('deepfm') torch.backends.cudnn.deterministic = True - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: meta_args = {k: v.to('meta') for k, v in data.items()} diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 9776452be9c8..94fb24f33376 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -53,7 +53,7 @@ def test_torchrec_dlrm_models(): torch.backends.cudnn.deterministic = True dlrm_models = model_zoo.get_sub_registry('dlrm') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items(): data = data_gen_fn() # dlrm_interactionarch is not supported diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index bd259475ae5a..74cb753e2937 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -10,7 +10,7 @@ def test_torchvision_models(): torch.backends.cudnn.deterministic = True tv_sub_registry = model_zoo.get_sub_registry('torchvision') - for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, model_attribute) in tv_sub_registry.items(): data = data_gen_fn() if model_attribute is not None and model_attribute.has_stochastic_depth_prob: diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 3879363bcd1b..73c3c5422d8a 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -6,6 +6,7 @@ import torch from packaging import version +from colossalai.device.device_mesh import DeviceMesh from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor from 
colossalai.tensor.d_tensor import to_global from colossalai.tensor.d_tensor.layout import Layout @@ -82,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, print(f'{model.__class__.__name__} pass') -def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None: +def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, + sharding_spec_dict: dict) -> None: state = model.state_dict() distributed_state = distributed_model.state_dict() diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py index f33c037e3de6..622d9deb601d 100644 --- a/tests/test_lazy/test_distribute.py +++ b/tests/test_lazy/test_distribute.py @@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]: return dim -def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout: +def make_sharding_spec(original_tensor: torch.Tensor) -> Layout: shard_dim = find_shard_dim(original_tensor.shape) dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {} target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - return layout + return target_sharding_spec def _get_current_name(prefix: str, name: str) -> str: return f'{prefix}.{name}'.lstrip('.') -def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict: - layout_dict = {} +def generate_sharding_spec_dict(model: nn.Module) -> dict: + sharding_spec_dict = {} @torch.no_grad() def generate_recursively(module: nn.Module, prefix: str = ''): @@ -53,17 +49,17 @@ def generate_recursively(module: nn.Module, prefix: str = ''): # initialize tensors directly attached to the current module for name, param in module.named_parameters(recurse=False): if isinstance(param, LazyTensor): - layout = make_layout(device_mesh, param) - layout_dict[_get_current_name(prefix, name)] = layout + sharding_spec = make_sharding_spec(param) + sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec for name, buf in module.named_buffers(recurse=False): if isinstance(buf, LazyTensor): - layout = make_layout(device_mesh, buf) - layout_dict[_get_current_name(prefix, name)] = layout + sharding_spec = make_sharding_spec(buf) + sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec generate_recursively(model) - return layout_dict + return sharding_spec_dict @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) @@ -75,7 +71,7 @@ def run_dist_lazy_init(subset, seed: int = 42): for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): + if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'): continue print_rank_0(name) model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry @@ -85,9 +81,9 @@ def run_dist_lazy_init(subset, seed: int = 42): ctx = LazyInitContext() with ctx: deferred_model = model_fn() - layout_dict = generate_layout_dict(deferred_model, device_mesh) - ctx.distribute(deferred_model, layout_dict, verbose=True) - assert_dist_model_equal(model, deferred_model, layout_dict) + sharding_spec_dict = 
generate_sharding_spec_dict(deferred_model) + ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True) + assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict) def run_dist(rank, world_size, port) -> None: diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index f828b23a94c4..4b7aeed73a69 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -10,7 +10,7 @@ def test_torchvision_models_lazy_init(subset): sub_model_zoo = model_zoo.get_sub_registry(subset) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): + if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'): continue check_lazy_init(entry, verbose=True) diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index 958eabb65fac..95fcd2aaf8f3 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -122,23 +122,6 @@ def check_all_reduce_bwd(process_groups_dict, rank): assert tensor_to_comm.equal(tensor_to_check) -def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank): - # tensor to comm - tensor_to_comm = torch.ones(2, 2).cuda() * rank - - # reduce through logical process axis 0 at flatten device mesh - # tensor to check - # tensor([[6., 6.], - # [6., 6.]]) - tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() - - # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) - comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0) - tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) - - assert tensor_to_comm.equal(tensor_to_check) - - def check_comm(rank, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -150,24 +133,22 @@ def check_comm(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - process_groups_dict = device_mesh.process_groups_dict + + process_group_dict = device_mesh._process_group_dict[rank] # test all gather - check_all_gather(process_groups_dict, rank) + check_all_gather(process_group_dict, rank) # test shard - check_shard(process_groups_dict, rank) + check_shard(process_group_dict, rank) # test all to all - check_all_to_all(process_groups_dict, rank) + check_all_to_all(process_group_dict, rank) # test all reduce - check_all_reduce_fwd(process_groups_dict, rank) - check_all_reduce_bwd(process_groups_dict, rank) + check_all_reduce_fwd(process_group_dict, rank) + check_all_reduce_bwd(process_group_dict, rank) - flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict - # test all reduce in 1D flatten device mesh - check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank) gpc.destroy() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 8350fb3e7fe6..5a1aef79f332 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -64,7 +64,7 @@ def check_dtensor(rank, world_size, port): else: raise ValueError(f'rank {rank} is not in the device mesh') - dtensor_from_local = 
distribute_tensor(original_tensor, new_layout) + dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec) if rank == 0: assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1)) diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index d9dff8af933d..5388fd901e09 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn -entire_shape = torch.Size((64, 32, 16)) +global_shape = torch.Size((64, 32, 16)) layout_converter = LayoutConverter() -physical_mesh_id = torch.arange(0, 4).reshape(2, 2) +physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (2, 2) sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec, - entire_shape=entire_shape) + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) @@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all) - layout_all2all = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_all2all, - entire_shape=entire_shape) + layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape) rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) @@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,R,R # device_mesh_shape: (4, 4) sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard) - shard_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_shard, - entire_shape=entire_shape) + shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape) rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) @@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) @@ -137,7 +122,7 @@ def 
check_layout_converting(rank, world_size, port): assert comm_action_sequence[2].shard_dim == 0 assert comm_action_sequence[2].logical_process_axis == 1 - # checkout cached_spec_pairs_transform_path + # checkout chached_spec_pairs_transform_path assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence @@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) - original_tensor = torch.rand(entire_shape).cuda() + original_tensor = torch.rand(global_shape).cuda() # tensor_to_apply: [R, S01, R] tensor_to_apply = original_tensor.narrow(1, rank * 8, 8) diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 6fe9ee292cd0..859eef051256 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -1,9 +1,10 @@ -from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern import torch -from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec + from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -physical_mesh_id = torch.arange(0, 16).reshape(2, 8) +physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py index d66d4fec14d1..9bd9805e9b8f 100644 --- a/tests/test_tensor/test_sharded_linear.py +++ b/tests/test_tensor/test_sharded_linear.py @@ -26,7 +26,7 @@ def run_dist(rank, world_size, port): # the mesh is in the following topo # [[0, 1], # [2, 3]] - physical_mesh_id = torch.arange(0, 4).reshape(2, 2) + physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) row_id = rank // 2 diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py index 909c84ef0f0e..5007c4141849 100644 --- a/tests/test_tensor/test_sharding_spec.py +++ b/tests/test_tensor/test_sharding_spec.py @@ -5,7 +5,7 @@ def test_sharding_spec(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7],