diff --git a/.github/workflows/hpu-gaudi2-nightly.yml b/.github/workflows/hpu-gaudi2-nightly.yml index 5c5caff1ebb0..c0576360cd61 100644 --- a/.github/workflows/hpu-gaudi2-nightly.yml +++ b/.github/workflows/hpu-gaudi2-nightly.yml @@ -21,7 +21,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest + image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml index a06f871b7c56..b8b6f3cb5502 100644 --- a/.github/workflows/hpu-gaudi2.yml +++ b/.github/workflows/hpu-gaudi2.yml @@ -39,7 +39,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest + image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml index 329a1060f5eb..7e209cbe4397 100644 --- a/.github/workflows/nv-ds-chat.yml +++ b/.github/workflows/nv-ds-chat.yml @@ -43,7 +43,7 @@ jobs: - name: Install deepspeed run: | - pip install transformers==4.45.2 + pip install transformers pip install .[dev] ds_report diff --git a/SECURITY.md b/SECURITY.md index 9e9391ee0bac..3061748e610b 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -39,3 +39,7 @@ We prefer all communications to be in English. Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). + +--- + +Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models. diff --git a/blogs/windows/08-2024/README.md b/blogs/windows/08-2024/README.md index 34e11bd47792..8a23372a1d64 100644 --- a/blogs/windows/08-2024/README.md +++ b/blogs/windows/08-2024/README.md @@ -48,7 +48,7 @@ Regardless of the installation choice, you can check that the installation was s We use an image classification model, CIFAR10, and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed. ## Pretraining CIFAR10 -The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py –deepspeed`. The final output should look something like this: +The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py --deepspeed`. The final output should look something like this:
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index cfca1ff4fe4c..131dce07d22d 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -80,7 +80,6 @@ def __init__(self, model, config): self.mp_group = config.tensor_parallel.tp_group self.mpu = config.tensor_parallel.mpu - #self._validate_args(self.mpu, config.replace_with_kernel_inject) self.quantize_merge_count = 1 self.quantization_scales = None @@ -300,29 +299,6 @@ def _init_quantization_setting(self, quantization_setting): f"mlp_extra_grouping = {self.mlp_extra_grouping}, " f"quantize_groups = {self.quantize_groups}", [0]) - # TODO: remove this function and add this functionality to pydantic config checking - def _validate_args(self, mpu, replace_with_kernel_inject): - # TODO: to support SD pipeline we need to avoid this check for now - if replace_with_kernel_inject and not isinstance(self.module, Module): - raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}") - if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1: - raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}") - - if mpu: - methods = ["get_model_parallel_group", "get_data_parallel_group"] - for method in methods: - if not hasattr(mpu, method): - raise ValueError(f"mpu is missing {method}") - if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)): - raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}") - - supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16] - if self._config.dtype not in supported_dtypes: - raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}") - - if self.injection_dict is not None and not isinstance(self.injection_dict, dict): - raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}") - def load_model_with_checkpoint(self, r_module): self.mp_replace = ReplaceWithTensorSlicing( mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 5441000e581d..66d7c2659359 100755 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -134,7 +134,8 @@ def is_load_module(module): load_layer_names = [ "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear", "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding", - "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm" + "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm", + "DeepseekV2RMSNorm", "DeepseekV2YarnRotaryEmbedding", "MoEGate" ] return module.__class__ in load_layers or module._get_name() in load_layer_names @@ -332,9 +333,9 @@ def _replace(self, child, name, conv_linear_layer): return weight_shape = child.weight.shape mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) - # For mixtral-7x8b, need to skip MoE gate linear replace. - if name == "block_sparse_moe.gate" or (('mlp.shared_expert_gate' == name or 'mlp.gate' == name) - and 'qwen2_moe' in str(type(self.module))): + # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip + if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or ( + ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))): return child # For Yuan model if 'Yuan' in str(self.module): @@ -350,11 +351,15 @@ def _replace(self, child, name, conv_linear_layer): arctic_w2_all_reduce_linear = False if 'Arctic' in str(self.module) and 'w2' in name: arctic_w2_all_reduce_linear = True + # For MoE MLP model, e.g., deepseek and jamba + down_proj = False + if 'down_proj' in name: + down_proj = True # For MLP including chunk layer. if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)): weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size()) return LinearLayer(weight=weight, bias=bias) - if name in self.all_reduce_linears or arctic_w2_all_reduce_linear: + if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj: # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] # else [weight_shape[0], weight_shape[1] // mp_size] diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py index a78ac8120346..7a9b9ca2065b 100644 --- a/deepspeed/module_inject/containers/bloom.py +++ b/deepspeed/module_inject/containers/bloom.py @@ -19,6 +19,18 @@ class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer): def __init__(self, **kwargs): + # Check transformers version, error if > 4.43.4 (breaks at 4.44.0) + from importlib.metadata import version + v_transformers = version('transformers') + vers = v_transformers.split('.') + major = int(vers[0]) + minor = int(vers[1]) + if major > 4 or (major == 4 and minor > 43): + import sys + sys.exit( + f"Transformers version {v_transformers} exceeds version 4.43.4! After transformers version 4.43.4, BLOOM inference with DeepSpeed is no longer supported." + ) + super().__init__(**kwargs) # All model specific things should be defined here instead of the base class. diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 22d24820d404..722ba413a671 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -191,7 +191,7 @@ def __init__(self, weight_shape=None, weight=None, bias=None): self.offset = 2 super().__init__(weight_shape, weight=weight) - def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: int = 0): """`input_ids_shape` is expected to be [bsz x seqlen].""" attention_mask = attention_mask.long() diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index 3e6fc2b63ef1..ded262edcf61 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -42,11 +42,16 @@ def get_num_attention_heads(): def get_shard_size(total_size, mp_size, name=None, rank=None): global num_kv_heads last_linear = ["lm_head", "embed_out"] + # MoE MLP layer use near even division will get better perf. + moe_mlp_layer = ["gate_proj", "up_proj", "down_proj", "w1", "w2", "w3"] + not_moe_mlp_layer = True + if name != None and any(s in str(name) for s in moe_mlp_layer): + not_moe_mlp_layer = False # When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division if rank == None: rank = dist.get_rank() if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str( - name) not in last_linear: + name) not in last_linear and not_moe_mlp_layer: my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0) return total_size * my_slices // num_kv_heads else: diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py index 51377bc6092c..f9cf23373c26 100644 --- a/deepspeed/ops/fp_quantizer/__init__.py +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -4,9 +4,4 @@ # DeepSpeed Team from .quantize import FP_Quantize, Quantizer - -try: - import triton - from .fp8_gemm import matmul_fp8 -except ImportError: - pass +from .fp8_gemm import matmul_fp8 diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm.py b/deepspeed/ops/fp_quantizer/fp8_gemm.py index 55504e3af8c9..db4fa5ae2c92 100644 --- a/deepspeed/ops/fp_quantizer/fp8_gemm.py +++ b/deepspeed/ops/fp_quantizer/fp8_gemm.py @@ -11,161 +11,18 @@ ################################### import torch -import triton -import triton.language as tl -@triton.jit -def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, - stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - quantization_group_size: tl.constexpr): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m +def matmul_fp8(inp, weight, scale, quantization_group_size, quantizer): + from deepspeed import get_accelerator - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) + if not get_accelerator().is_triton_supported(): + return matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer) + else: + # Import dynamically to prevent failures on systems without triton. + from .fp8_gemm_triton import matmul_fp8_triton + return matmul_fp8_triton(inp, weight, scale, quantization_group_size) - inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( - (pid_n * BLOCK_SIZE_N) // quantization_group_size) - weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) - scale = tl.load(scale_ptr + weight_ptrs_offset) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - # Dequantize weight (fp8 -> bf16) - w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) - w = (w + 0x3C00).to(tl.uint16) - w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) - - inp_data += BLOCK_SIZE_K * stride_ak - weight_data += BLOCK_SIZE_K * stride_bk - weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K - weight = tl.load(weight_data, mask=weight_mask, other=0.0) - scale = tl.load(scale_ptr + (weight_ptrs_offset + - (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), - mask=weight_mask, - other=0.0) - - accumulator += tl.dot(inp, w) - - out = accumulator.to(tl.bfloat16) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) - - -@triton.jit -def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, - stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - quantization_group_size: tl.constexpr): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - - inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( - (pid_n * BLOCK_SIZE_N) // quantization_group_size) - - weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) - scale = tl.load(scale_ptr + weight_ptrs_offset) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - # Dequantize weight (fp8 -> fp16) - w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) - w = (w + 0x2000).to(tl.uint16) - w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) - - inp_data += BLOCK_SIZE_K * stride_ak - weight_data += BLOCK_SIZE_K * stride_bk - - weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0) - scale = tl.load(scale_ptr + (weight_ptrs_offset + - (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size))) - - accumulator += tl.dot(inp, w) - - out = accumulator.to(tl.float16) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) - - -def matmul_fp8(inp, weight, scale, quantization_group_size): - - assert inp.shape[1] == weight.shape[0], \ - f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})" - - M, K = inp.shape - K, N = weight.shape - - out = torch.empty((M, N), device=inp.device, dtype=inp.dtype) - - # GEMM tuning parameters! - # TODO: Add a more configurable tuning for selecting the best GeMM - BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128 - BLOCK_SIZE_N = 64 - BLOCK_SIZE_K = max(64, quantization_group_size) - GROUP_SIZE_M = 8 - num_stages = 4 - num_warps = 4 - if M >= 256: - BLOCK_SIZE_M = 256 - BLOCK_SIZE_N = 128 - BLOCK_SIZE_K = max(128, quantization_group_size) - num_stages = 3 - num_warps = 8 - - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16 - kernel[grid](inp, - weight, - out, - scale, - M, - N, - K, - inp.stride(0), - inp.stride(1), - weight.stride(0), - weight.stride(1), - out.stride(0), - out.stride(1), - quantization_group_size=quantization_group_size, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, - GROUP_SIZE_M=GROUP_SIZE_M, - num_stages=num_stages, - num_warps=num_warps) - return out +def matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer): + return torch.matmul(inp, quantizer.dequantize(weight, scale=scale)) diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py new file mode 100644 index 000000000000..746e217d4194 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +######## Fused MoE kernel ######### +# These kernels are implemented for +# fusing GeMM with dequantization of +# fp8 weight data when using bit-16 +# activation. +################################### + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> bf16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) + w = (w + 0x3C00).to(tl.uint16) + w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K + weight = tl.load(weight_data, mask=weight_mask, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), + mask=weight_mask, + other=0.0) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.bfloat16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +@triton.jit +def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> fp16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) + w = (w + 0x2000).to(tl.uint16) + w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + + weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size))) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +def matmul_fp8_triton(inp, weight, scale, quantization_group_size): + + assert inp.shape[1] == weight.shape[0], \ + f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})" + + M, K = inp.shape + K, N = weight.shape + + out = torch.empty((M, N), device=inp.device, dtype=inp.dtype) + + # GEMM tuning parameters! + # TODO: Add a more configurable tuning for selecting the best GeMM + BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = max(64, quantization_group_size) + GROUP_SIZE_M = 8 + num_stages = 4 + num_warps = 4 + if M >= 256: + BLOCK_SIZE_M = 256 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = max(128, quantization_group_size) + num_stages = 3 + num_warps = 8 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16 + kernel[grid](inp, + weight, + out, + scale, + M, + N, + K, + inp.stride(0), + inp.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + quantization_group_size=quantization_group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + num_stages=num_stages, + num_warps=num_warps) + return out diff --git a/deepspeed/ops/transformer/inference/triton/matmul_ext.py b/deepspeed/ops/transformer/inference/triton/matmul_ext.py index 412c8740a216..9be4b0098c37 100644 --- a/deepspeed/ops/transformer/inference/triton/matmul_ext.py +++ b/deepspeed/ops/transformer/inference/triton/matmul_ext.py @@ -19,6 +19,9 @@ # ----------------------------------------------------------------------------- # util class/functions for triton def is_nfs_path(path): + if os.name == 'nt': + return False + # Normalize the path to get the absolute path path = os.path.abspath(path) @@ -99,7 +102,7 @@ def put(self, table): with FileLock(self.lock_path): with open(self.file_path + ".tmp", 'wb') as handle: pickle.dump(table, handle) - os.rename(self.file_path + ".tmp", self.file_path) + os.replace(self.file_path + ".tmp", self.file_path) def load(self): if os.path.exists(self.file_path): diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index fa9220f4fcd0..be778b83f8bb 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.utils.torch import required_torch_version try: from torch.compiler import is_compiling as torch_is_compiling @@ -16,7 +17,7 @@ def is_compile_supported(): - return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile") + return required_torch_version(min_version=2.1) def disable(func): diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py index 8eb95e49c29d..88c5494c8147 100644 --- a/deepspeed/runtime/domino/transformer.py +++ b/deepspeed/runtime/domino/transformer.py @@ -6,8 +6,7 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter -import deepspeed -from deepspeed import comm as dist +import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator @@ -97,7 +96,7 @@ def backward(ctx, grad_output): return grad_output # Async All-reduce. - handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True) + handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True) ctx.handle_dic[ctx.h_id] = handle return None, grad_output, None, None @@ -249,6 +248,10 @@ def __init__(self, output_bias=None): super(DominoTransformerLayer, self).__init__() + if not dist.is_initialized(): + dist.init_distributed() + assert dist.is_initialized(), "deepspeed.comm is not initialized!" + self.llama_model = config.llama_model self.layer_number = layer_number self.layer_type = layer_type @@ -358,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): layernorm_output0, attention_mask, rotary_pos_emb=rotary_pos_emb) - handle0 = deepspeed.comm.all_reduce(attention_output0, - group=self.mpu.get_tensor_model_parallel_group(), - async_op=True) + handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) attention_output1, attention_bias1 = \ self.self_attention( layernorm_output1, attention_mask, rotary_pos_emb=rotary_pos_emb) - handle1 = deepspeed.comm.all_reduce(attention_output1, - group=self.mpu.get_tensor_model_parallel_group(), - async_op=True) + handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) handle0.wait() # Residual0 connection. @@ -413,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): output0 = output0 + bias_c output0 = self.mlp_activation_func(output0) output0 = torch.matmul(output0, self.weight_r.t()) - handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) + handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) handle1.wait() @@ -425,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): if bias_c is not None: output1 = output1 + bias_c output1 = torch.matmul(output1, self.weight_r.t()) - deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group()) + dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group()) handle2.wait() diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 31fec30be788..49fa2807c355 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -116,7 +116,9 @@ def forward(self, inputs): partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'. activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing. activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``. - checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering. + checkpointable_layers (list[str], optional): List of layer class names that are eligible for checkpointing. For GPT models, + ParallelTransformerLayerPipe is always checkpointed regardless of this list. If None, all layers with parameters are + considered checkpointable. Defaults to None. dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact. """ @@ -650,9 +652,17 @@ def _is_checkpointable(self, funcs): # because only non_reentrant_checkpoint can accept inputs with requires_grad=False # otherwise, the backward of the embedding layer won't receive gradients. if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'): - return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs) + # For GPT models, checkpoint both transformer layers and any additional + # layers specified in checkpointable_layers (if provided) + return all('ParallelTransformerLayerPipe' in f.__class__.__name__ or ( + self.checkpointable_layers is not None and f.__class__.__name__ in self.checkpointable_layers) + for f in funcs) + if self.checkpointable_layers is not None: + # For non-GPT models, only checkpoint layers specified in checkpointable_layers return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs) + + # Default behavior: checkpoint any layer that has parameters params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)] return any(len(list(p)) > 0 for p in params) @@ -662,3 +672,11 @@ def get_additional_losses(self): Return a dictionary of {"loss name": loss_value} or None if no additional losses. """ return None + + def compile(self, *args, **kwargs): + for idx, layer in enumerate(self.forward_funcs): + if isinstance(layer, nn.Module): + layer.compile(*args, **kwargs) + else: + new_layer = torch.compile(layer, *args, **kwargs) + self.forward_funcs[idx] = new_layer diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py index c9ae58a121de..628bf86a61da 100755 --- a/deepspeed/runtime/zero/mics.py +++ b/deepspeed/runtime/zero/mics.py @@ -38,7 +38,7 @@ class MiCS_AllGatherCoalescedHandle(AllGatherCoalescedHandle): def __init__(self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int) -> None: super().__init__(allgather_handle, params, partitions, world_size) - def wait(self) -> None: + def wait(self, **kwargs) -> None: """ """ # let the current stream to op diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 0be88a1e1ba6..d5b7bac55146 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -145,6 +145,16 @@ def __init__( module.ds_inflight_param_registry = InflightParamRegistry() self.__inflight_param_registry = module.ds_inflight_param_registry + self.fast_sharding_for_leaf_module = False + + if zero_module_granularity_threshold > 0: + self.min_granularity_value = sys.maxsize + self.min_granularity_layer = None + self.granularity_info = set() + self.z3_leaf_layers = [] + self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold) + self.fast_sharding_for_leaf_module = True + self.param_coordinator = PartitionedParameterCoordinator( prefetch_bucket_sz=self._prefetch_bucket_sz, max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, @@ -155,14 +165,7 @@ def __init__( timers=self.timers, zero_quantized_weights=self.zero_quantized_weights, zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, - ) - - if zero_module_granularity_threshold > 0: - self.min_granularity_value = sys.maxsize - self.min_granularity_layer = None - self.granularity_info = set() - self.z3_leaf_layers = [] - self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold) + fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module) self.forward_hooks = [] self.backward_hooks = [] diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index cb0cd7c8017d..e8cb797b8a5b 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -55,7 +55,7 @@ def __init__(self, param: Parameter) -> None: non_blocking=True).view(param.ds_shape) self.__param = param - def wait(self) -> None: + def wait(self, **kwargs) -> None: if not get_accelerator().resolves_data_dependency(): get_accelerator().current_stream().synchronize() self.__param.ds_status = ZeroParamStatus.AVAILABLE @@ -78,7 +78,7 @@ def __init__(self, params: List[Parameter]) -> None: non_blocking=True).view(param.ds_shape) @instrument_w_nvtx - def wait(self) -> None: + def wait(self, **kwargs) -> None: if self.__complete: return @@ -639,7 +639,7 @@ def __init__(self, handle, param: Parameter, quantization=None) -> None: self.__param = param self.__quantization = quantization - def wait(self) -> None: + def wait(self, handle_dependency=True) -> None: instrument_w_nvtx(self.__handle.wait)() if self.__quantization: instrument_w_nvtx(self.__quantization.quant_handle.wait)() @@ -650,6 +650,8 @@ def wait(self) -> None: class AllGatherCoalescedHandle: + data_buffer = [] + def __init__( self, allgather_handle, @@ -672,7 +674,7 @@ def __init__( raise RuntimeError(f"expected param {param.ds_summary()} to not be available") @instrument_w_nvtx - def wait(self) -> None: + def wait(self, handle_dependency=True) -> None: if self.complete: return @@ -704,14 +706,20 @@ def wait(self) -> None: partitions.append(part_to_copy) param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape) param.ds_status = ZeroParamStatus.AVAILABLE - - for part_to_copy in partitions: - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().is_synchronized_device() and handle_dependency: + for part_to_copy in partitions: part_to_copy.record_stream(get_accelerator().current_stream()) param_offset += ds_tensor_numel self.complete = True + if not get_accelerator().is_synchronized_device() and not handle_dependency: + # if the device needs to handle dependencies and opts for explicit processing outside the function. + AllGatherCoalescedHandle.data_buffer.append(partitions) + + @staticmethod + def free_buffer(): + AllGatherCoalescedHandle.data_buffer = [] class MultipleAllGatherHandles: @@ -719,9 +727,9 @@ class MultipleAllGatherHandles: def __init__(self, handles: List[AllGatherCoalescedHandle]): self.handles = handles - def wait(self) -> None: + def wait(self, handle_dependency=True) -> None: for handle in self.handles: - handle.wait() + handle.wait(handle_dependency) class AllReduceCoalescedHandle: @@ -1377,13 +1385,13 @@ def all_gather_coalesced(params: Iterable[Parameter], quantization=quant_info, ) - def partition(param_list=None, hierarchy=0, has_been_updated=False): + def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True): cls = param print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}", force=False) if param_list is None: param_list = [cls] - self._partition(param_list, has_been_updated=has_been_updated) + self._partition(param_list, has_been_updated=has_been_updated, free_data=True) def reduce_gradients_at_owner(param_list=None, hierarchy=0): cls = param @@ -1527,12 +1535,12 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): return handles - def _partition(self, param_list, force=False, has_been_updated=False): + def _partition(self, param_list, force=False, has_been_updated=False, free_data=True): for param in param_list: print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False) if self.zero_param_process_group is not None: self._partition_param_sec(param) - self._partition_param(param, has_been_updated=has_been_updated) + self._partition_param(param, has_been_updated=has_been_updated, free_data=True) param.ds_status = ZeroParamStatus.NOT_AVAILABLE # if param.ds_tensor is not None: @@ -1540,7 +1548,7 @@ def _partition(self, param_list, force=False, has_been_updated=False): # "After the parameters are initially partitioned, make sure we are not recreating the partition." #print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False) @instrument_w_nvtx - def _partition_param(self, param, buffer=None, has_been_updated=False): + def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True): assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" global reuse_buffers print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False) @@ -1565,7 +1573,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) # param.data does not store anything meaningful in partitioned state - free_param(param) + if free_data: + free_param(param) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) if param.ds_tensor.final_location == OffloadDeviceEnum.nvme: diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 596d0e9c20f9..08cb6c0de54f 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -76,18 +76,17 @@ class __ParamInTrace: param: Parameter step_id_last_used_at: int - def __init__( - self, - prefetch_bucket_sz: int, - max_reuse_distance_in_numel: int, - max_available_parameters_in_numel: int, - allgather_stream: get_accelerator().Stream, - inflight_param_registry: InflightParamRegistry, - prefetch_nvme: bool = False, - timers=None, - zero_quantized_weights=False, - zero_quantized_nontrainable_weights=False, - ) -> None: + def __init__(self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: get_accelerator().Stream, + inflight_param_registry: InflightParamRegistry, + prefetch_nvme: bool = False, + timers=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + fast_sharding_for_leaf_module=False) -> None: # mapping of param -> handle for each param that is currently in flight self.__inflight_param_registry = inflight_param_registry # keeps track of the number of submodules invoked so far. @@ -130,6 +129,10 @@ def __init__( self.__max_ongoing_fetch_events: int = 2 self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None) + # whether to enable fast fetch for the z3 leaf module. + # this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure. + self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module + """Tracing and Tracking TODO. consider performing trace before initializing PartitionedParameterCoordinator and passing trace results into constructor. This way all the code in here can @@ -308,6 +311,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: wait_numel = 0 wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT self.__profiler.start_event(wait_event_name) + fast_fetch = self.fast_sharding_for_leaf_module and z3_leaf_module(current_submodule) # wait for parameters in the immediately needed submodule to become available for param in params_to_fetch: param.ds_active_sub_modules.add(current_submodule.id) @@ -321,9 +325,9 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events: self.__ongoing_fetch_events.popleft().synchronize() - self.__inflight_param_registry.pop(param).wait() + self.__inflight_param_registry.pop(param).wait(handle_dependency=not fast_fetch) - if not get_accelerator().handles_memory_backpressure(): + if not get_accelerator().handles_memory_backpressure() and not fast_fetch: event = get_accelerator().Event() event.record() self.__ongoing_fetch_events.append(event) @@ -331,6 +335,8 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() if not get_accelerator().resolves_data_dependency(): get_accelerator().current_stream().wait_stream(self.__allgather_stream) + if fast_fetch: + AllGatherCoalescedHandle.free_buffer() self.__profiler.stop_event(wait_event_name, wait_numel) # kick off parameter prefetches for upcoming modules @@ -412,10 +418,20 @@ def release_sub_module(self, submodule: Module) -> None: be released.""" params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set( p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule)))) + + free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module + if not free_data: + # wait for the computation to finish and launch as early as possible. + empty_buffer = torch.empty(1, device=get_accelerator().current_device()) + for param in iter_params(submodule, recurse=z3_leaf_module(submodule)): param.ds_active_sub_modules.discard(submodule.id) if param.ds_id in params_to_release and not param.is_external_param: - self.__release_param(param) + self.__release_param(param, free_data) + if not free_data: + if param.ds_id in params_to_release and not param.is_external_param: + # empty buffer ensures that all computations are complete + param.data = empty_buffer @instrument_w_nvtx @torch.no_grad() @@ -490,11 +506,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize: @compiler.disable @instrument_w_nvtx - def __release_param(self, param: Parameter) -> None: + def __release_param(self, param: Parameter, free_data: bool = True) -> None: if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: if logger.isEnabledFor(logging.DEBUG): debug_rank0(f"-release: {param.ds_summary()}") - param.partition() + param.partition(free_data=free_data) self.__n_available_params -= param.ds_numel @instrument_w_nvtx diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 459cffce52c8..28f91cb9b3ab 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -16,6 +16,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.utils import logger +from deepspeed.utils.torch import register_grad_hook from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item @@ -1159,7 +1160,6 @@ def overlapping_partition_gradients_reduce_epilogue(self): def create_reduce_and_remove_grad_hooks(self): print_rank_0(f'[Begin] Create gradient reduction hooks') - self.grad_accs = [] self.leaf_parameters = defaultdict(list) for i, param_group in enumerate(self.fp16_groups): for param in param_group: @@ -1172,15 +1172,12 @@ def create_reduce_and_remove_grad_hooks(self): #print(f"After all gather {param.device}, {param.shape}") def wrapper(param): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] @instrument_w_nvtx def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param) - self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) - self.grad_accs.append(grad_acc) + self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads)) #print(f"param grad fn {param.expand_as(param).grad_fn}") if z3_leaf_parameter(param): diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 4fab768ce63c..4fa2cc988a19 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -47,7 +47,7 @@ def _update_out_and_lse( block_out = block_out.to(torch.float32) block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + new_lse = lse + torch.log1p(torch.exp(block_lse - lse)) out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out diff --git a/deepspeed/utils/torch.py b/deepspeed/utils/torch.py index eb22d3561035..1d32775fe64a 100644 --- a/deepspeed/utils/torch.py +++ b/deepspeed/utils/torch.py @@ -20,3 +20,12 @@ def required_torch_version(min_version=None, max_version=None): return False return True + + +def register_grad_hook(param, hook): + if required_torch_version(min_version=2.1): + return param.register_post_accumulate_grad_hook(hook) + else: + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + return grad_acc.register_hook(hook) diff --git a/op_builder/builder.py b/op_builder/builder.py index 461281d4a569..ab26054bda7d 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -415,10 +415,11 @@ def cpu_arch(self): return '-mcpu=native' return '-march=native' - def is_cuda_enable(self): + def get_cuda_compile_flag(self): try: - assert_no_cuda_mismatch(self.name) - return '-D__ENABLE_CUDA__' + if not self.is_rocm_pytorch(): + assert_no_cuda_mismatch(self.name) + return "-D__ENABLE_CUDA__" except MissingCUDAException: print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " "only cpu ops can be compiled!") @@ -839,7 +840,7 @@ def cxx_args(self): CPU_ARCH = self.cpu_arch() SIMD_WIDTH = self.simd_width() - CUDA_ENABLE = self.is_cuda_enable() + CUDA_ENABLE = self.get_cuda_compile_flag() args += [ CPU_ARCH, '-fopenmp', diff --git a/setup.py b/setup.py index c0452f867b31..cc5eb4a3500c 100755 --- a/setup.py +++ b/setup.py @@ -321,9 +321,9 @@ def op_enabled(op_name): include_package_data=True, scripts=scripts, classifiers=[ - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10' + 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12' ], license='Apache Software License 2.0', ext_modules=ext_modules, diff --git a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py index d66f7c8cb4cc..a4cf579f5943 100644 --- a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py +++ b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py @@ -14,6 +14,8 @@ from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8 +from deepspeed import get_accelerator + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) @pytest.mark.parametrize("q_bits", [8], ids=[ @@ -21,23 +23,19 @@ ]) @pytest.mark.parametrize("M", [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024, 2048]) def test_fp_quant(dtype, q_bits, M): + device_name = get_accelerator().device_name() quantization_group_size = 128 fpq = FP_Quantize(group_size=quantization_group_size) N = 8192 H = 4096 - x = torch.randn(M, H, dtype=dtype, device='cuda') - weight_bf16 = torch.randn(H, N, dtype=dtype, device='cuda') + x = torch.randn(M, H, dtype=dtype, device=device_name) + weight_bf16 = torch.randn(H, N, dtype=dtype, device=device_name) - weight, _ = fpq.quantize(weight_bf16.data, q_bits=8, return_meta_tensor=True) + weight, _ = fpq.quantize(weight_bf16.data, q_bits=q_bits, return_meta_tensor=True) scale = fpq.get_scales() - out = matmul_fp8( - x, - weight, - scale, - quantization_group_size, - ) + out = matmul_fp8(x, weight, scale, quantization_group_size, fpq) out_q = torch.matmul(x, fpq.dequantize(weight, scale=fpq.scale)) diff --git a/tests/unit/ops/transformer/inference/inference_test_utils.py b/tests/unit/ops/transformer/inference/inference_test_utils.py index 9cfcae809f09..d63c51267e51 100644 --- a/tests/unit/ops/transformer/inference/inference_test_utils.py +++ b/tests/unit/ops/transformer/inference/inference_test_utils.py @@ -3,6 +3,8 @@ # DeepSpeed Team +from typing import Tuple + import torch from deepspeed.accelerator import get_accelerator @@ -23,38 +25,22 @@ def get_tolerances(): DTYPES = None -def get_dtypes(): +def get_dtypes(include_float=True): global DTYPES if DTYPES is None: - DTYPES = get_accelerator().supported_dtypes() + DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16] + try: + if get_accelerator().is_bf16_supported(): + DTYPES.append(torch.bfloat16) + except (AssertionError, AttributeError): + pass return DTYPES -def allclose(x, y): +def allclose(x, y, tolerances: Tuple[int, int] = None): assert x.dtype == y.dtype - rtol, atol = get_tolerances()[x.dtype] + if tolerances is None: + rtol, atol = get_tolerances()[x.dtype] + else: + rtol, atol = tolerances return torch.allclose(x, y, rtol=rtol, atol=atol) - - -def assert_almost_equal(x, y, decimal=2, err_msg=''): - import numpy.testing as npt - if isinstance(x, torch.Tensor): - if x.dtype == torch.bfloat16: - x = x.float() - x = x.cpu().detach().numpy() - if isinstance(y, torch.Tensor): - if y.dtype == torch.bfloat16: - y = y.float() - y = y.cpu().detach().numpy() - npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal) - - -def max_diff(a, b): - a = a.to(torch.float32).flatten() - b = b.to(torch.float32).flatten() - diff = torch.abs(a - b) - max_diff_indices = torch.argsort(diff)[-1] - print("Max difference indices:", max_diff_indices) - print("Max difference values:", diff[max_diff_indices]) - print(f"{a[max_diff_indices]} vs {b[max_diff_indices]}") - return max_diff_indices diff --git a/tests/unit/ops/transformer/inference/test_attention.py b/tests/unit/ops/transformer/inference/test_attention.py index ecf681542ff6..cae201d747a3 100644 --- a/tests/unit/ops/transformer/inference/test_attention.py +++ b/tests/unit/ops/transformer/inference/test_attention.py @@ -7,7 +7,7 @@ import torch import deepspeed from deepspeed.accelerator import get_accelerator -from .inference_test_utils import assert_almost_equal +from .inference_test_utils import allclose # reference timplementation @@ -88,4 +88,4 @@ def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float use_triton_flash=False, use_ds_attention=False) tri_out = tri_out.reshape((BATCH, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3) - assert_almost_equal(ref_out, tri_out) + assert (allclose(ref_out, tri_out)) diff --git a/tests/unit/ops/transformer/inference/test_bias_add.py b/tests/unit/ops/transformer/inference/test_bias_add.py index f25bbc1be692..eb283924f73c 100644 --- a/tests/unit/ops/transformer/inference/test_bias_add.py +++ b/tests/unit/ops/transformer/inference/test_bias_add.py @@ -15,8 +15,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -torch_minor_version = None - def run_bias_add_reference(activations, bias): return activations + bias diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py index e3a3bad63961..f0a09245e890 100644 --- a/tests/unit/ops/transformer/inference/test_bias_gelu.py +++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py @@ -10,8 +10,8 @@ from deepspeed.ops.op_builder import InferenceBuilder from deepspeed.ops.transformer import DeepSpeedInferenceConfig from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp +from deepspeed.utils.torch import required_torch_version from .inference_test_utils import allclose, get_dtypes -from packaging import version as pkg_version if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) @@ -34,7 +34,7 @@ def run_bias_gelu_ds(activations, bias): @pytest.mark.parametrize("channels", [512, 1232, 4096]) @pytest.mark.parametrize("dtype", get_dtypes()) def test_bias_gelu(batch, sequence, channels, dtype): - if pkg_version.parse(torch.__version__) < pkg_version.parse("1.12"): + if not required_torch_version(min_version=1.12): pytest.skip("gelu implementation matches only after torch 1.12") activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=get_accelerator().device_name()) diff --git a/tests/unit/ops/transformer/inference/test_layer_norm.py b/tests/unit/ops/transformer/inference/test_layer_norm.py index 7711daf0d887..4a84add16046 100644 --- a/tests/unit/ops/transformer/inference/test_layer_norm.py +++ b/tests/unit/ops/transformer/inference/test_layer_norm.py @@ -9,7 +9,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp -from .inference_test_utils import allclose, get_dtypes, assert_almost_equal +from .inference_test_utils import allclose, get_dtypes try: import triton # noqa: F401 # type: ignore from deepspeed.ops.transformer.inference.triton import ( @@ -188,4 +188,4 @@ def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device=' y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias, eps).to(dtype) # compare - assert_almost_equal(y_tri, y_ref) + assert (allclose(y_tri, y_ref)) diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py index 2ab195ee0115..6f5173bbc827 100644 --- a/tests/unit/ops/transformer/inference/test_matmul.py +++ b/tests/unit/ops/transformer/inference/test_matmul.py @@ -11,8 +11,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None - def allclose(x, y): assert x.dtype == y.dtype diff --git a/tests/unit/pipe/test_pipe_module.py b/tests/unit/pipe/test_pipe_module.py index 05c6a82ef55a..2a8a4b9b7d82 100644 --- a/tests/unit/pipe/test_pipe_module.py +++ b/tests/unit/pipe/test_pipe_module.py @@ -60,9 +60,12 @@ def batch_input(): class TestPipeModuleSequential(DistributedTest): world_size = 2 + # needs to be set for torch.compile: running torch.compile with daemonic process causes an error + non_daemonic_procs = True @pytest.mark.parametrize("activation_checkpoints", [False, True]) - def test(self, sequential_model, simple_config, batch_input, activation_checkpoints): + @pytest.mark.parametrize("use_compile", [False, True]) + def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile): base_model = copy.deepcopy(sequential_model) base_input = batch_input.clone().detach() base_output = base_model(base_input) @@ -71,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi pipe_model = copy.deepcopy(sequential_model) pipe_model = PipelineModule(layers=pipe_model, num_stages=2) - + if (use_compile): + pipe_model.compile() # Ensure all parameters are accounted for. my_params = sum(p.numel() for p in pipe_model.parameters()) total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name()) diff --git a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py index 22a61003b31e..dd3bcd7fb6bd 100644 --- a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py +++ b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py @@ -8,6 +8,7 @@ import pytest import torch import deepspeed +from deepspeed.pipe import PipelineModule, LayerSpec from deepspeed.accelerator import get_accelerator from copy import deepcopy from unit.common import DistributedTest @@ -259,3 +260,52 @@ def test_ckpt_non_tensor_output_ordering(self, non_tensor_output): else: ordering += [torch.is_tensor(non_tensor_output)] _test_activation_checkpoint_ordering(module, ordering, inputs) + + +class TestCheckpointableLayersConfig(DistributedTest): + world_size = 1 + + def test_gpt2_checkpointable_layers(self): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") + + # Create a simple topology for testing + from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology + topo = PipeModelDataParallelTopology(num_pp=1, num_mp=1, num_dp=1) + + # Create test classes that we want to checkpoint + class TestTransformerLayer(torch.nn.Module): + + def forward(self, x): + return x + + class ParallelTransformerLayerPipe(TestTransformerLayer): + pass + + class GMLPBlock(TestTransformerLayer): + pass + + # Create a mock GPT2 model with different layer types + class TestGPT2ModelPipe(PipelineModule): + + def __init__(self): + self.layers_spec = [ + LayerSpec(ParallelTransformerLayerPipe), + LayerSpec(GMLPBlock), + LayerSpec(torch.nn.Linear, 10, 10), # Should not be checkpointed + ] + + super().__init__(layers=self.layers_spec, + topology=topo, + checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"]) + + model = TestGPT2ModelPipe() + model.to(get_accelerator().device_name()) + + # Build layers manually for testing + layers = [spec.build() for spec in model.layers_spec] + + # Test that _is_checkpointable returns correct values + assert model._is_checkpointable([layers[0]]) == True # ParallelTransformerLayerPipe + assert model._is_checkpointable([layers[1]]) == True # GMLPBlock + assert model._is_checkpointable([layers[2]]) == False # Linear layer