diff --git a/.github/workflows/hpu-gaudi2-nightly.yml b/.github/workflows/hpu-gaudi2-nightly.yml
index 5c5caff1ebb0..c0576360cd61 100644
--- a/.github/workflows/hpu-gaudi2-nightly.yml
+++ b/.github/workflows/hpu-gaudi2-nightly.yml
@@ -21,7 +21,7 @@ jobs:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
- image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+ image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml
index a06f871b7c56..b8b6f3cb5502 100644
--- a/.github/workflows/hpu-gaudi2.yml
+++ b/.github/workflows/hpu-gaudi2.yml
@@ -39,7 +39,7 @@ jobs:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
- image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+ image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml
index 329a1060f5eb..7e209cbe4397 100644
--- a/.github/workflows/nv-ds-chat.yml
+++ b/.github/workflows/nv-ds-chat.yml
@@ -43,7 +43,7 @@ jobs:
- name: Install deepspeed
run: |
- pip install transformers==4.45.2
+ pip install transformers
pip install .[dev]
ds_report
diff --git a/SECURITY.md b/SECURITY.md
index 9e9391ee0bac..3061748e610b 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -39,3 +39,7 @@ We prefer all communications to be in English.
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
+
+---
+
+Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models.
diff --git a/blogs/windows/08-2024/README.md b/blogs/windows/08-2024/README.md
index 34e11bd47792..8a23372a1d64 100644
--- a/blogs/windows/08-2024/README.md
+++ b/blogs/windows/08-2024/README.md
@@ -48,7 +48,7 @@ Regardless of the installation choice, you can check that the installation was s
We use an image classification model, CIFAR10, and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed.
## Pretraining CIFAR10
-The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py –deepspeed`. The final output should look something like this:
+The scripts and code required for the CIFAR10 pretraining example are available at the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py --deepspeed`. The final output should look something like this:
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index cfca1ff4fe4c..131dce07d22d 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -80,7 +80,6 @@ def __init__(self, model, config):
self.mp_group = config.tensor_parallel.tp_group
self.mpu = config.tensor_parallel.mpu
- #self._validate_args(self.mpu, config.replace_with_kernel_inject)
self.quantize_merge_count = 1
self.quantization_scales = None
@@ -300,29 +299,6 @@ def _init_quantization_setting(self, quantization_setting):
f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
f"quantize_groups = {self.quantize_groups}", [0])
- # TODO: remove this function and add this functionality to pydantic config checking
- def _validate_args(self, mpu, replace_with_kernel_inject):
- # TODO: to support SD pipeline we need to avoid this check for now
- if replace_with_kernel_inject and not isinstance(self.module, Module):
- raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
- if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
- raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")
-
- if mpu:
- methods = ["get_model_parallel_group", "get_data_parallel_group"]
- for method in methods:
- if not hasattr(mpu, method):
- raise ValueError(f"mpu is missing {method}")
- if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
- raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
-
- supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
- if self._config.dtype not in supported_dtypes:
- raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
-
- if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
- raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")
-
def load_model_with_checkpoint(self, r_module):
self.mp_replace = ReplaceWithTensorSlicing(
mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 5441000e581d..66d7c2659359 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -134,7 +134,8 @@ def is_load_module(module):
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
- "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm"
+ "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm",
+ "DeepseekV2RMSNorm", "DeepseekV2YarnRotaryEmbedding", "MoEGate"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names
@@ -332,9 +333,9 @@ def _replace(self, child, name, conv_linear_layer):
return
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
- # For mixtral-7x8b, need to skip MoE gate linear replace.
- if name == "block_sparse_moe.gate" or (('mlp.shared_expert_gate' == name or 'mlp.gate' == name)
- and 'qwen2_moe' in str(type(self.module))):
+        # Skip TP replacement for certain layers, e.g., MoE gates and DeepSeek low-rank projections (q_a_proj, kv_a_proj_with_mqa).
+ if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
+ ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
return child
# For Yuan model
if 'Yuan' in str(self.module):
@@ -350,11 +351,15 @@ def _replace(self, child, name, conv_linear_layer):
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
+        # For MoE MLP models (e.g., DeepSeek and Jamba), treat down_proj as an all-reduce linear.
+ down_proj = False
+ if 'down_proj' in name:
+ down_proj = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
- if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
+ if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]
diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py
index a78ac8120346..7a9b9ca2065b 100644
--- a/deepspeed/module_inject/containers/bloom.py
+++ b/deepspeed/module_inject/containers/bloom.py
@@ -19,6 +19,18 @@
class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer):
def __init__(self, **kwargs):
+ # Check transformers version, error if > 4.43.4 (breaks at 4.44.0)
+ from importlib.metadata import version
+ v_transformers = version('transformers')
+ vers = v_transformers.split('.')
+ major = int(vers[0])
+ minor = int(vers[1])
+ if major > 4 or (major == 4 and minor > 43):
+ import sys
+ sys.exit(
+ f"Transformers version {v_transformers} exceeds version 4.43.4! After transformers version 4.43.4, BLOOM inference with DeepSpeed is no longer supported."
+ )
+
super().__init__(**kwargs)
# All model specific things should be defined here instead of the base class.
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index 22d24820d404..722ba413a671 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -191,7 +191,7 @@ def __init__(self, weight_shape=None, weight=None, bias=None):
self.offset = 2
super().__init__(weight_shape, weight=weight)
- def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py
index 3e6fc2b63ef1..ded262edcf61 100644
--- a/deepspeed/module_inject/tp_shard.py
+++ b/deepspeed/module_inject/tp_shard.py
@@ -42,11 +42,16 @@ def get_num_attention_heads():
def get_shard_size(total_size, mp_size, name=None, rank=None):
global num_kv_heads
last_linear = ["lm_head", "embed_out"]
+    # MoE MLP layers get better performance with near-even division.
+ moe_mlp_layer = ["gate_proj", "up_proj", "down_proj", "w1", "w2", "w3"]
+ not_moe_mlp_layer = True
+ if name != None and any(s in str(name) for s in moe_mlp_layer):
+ not_moe_mlp_layer = False
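+    # Example (hypothetical sizes): a DeepSeek down_proj with total_size=11008 and mp_size=4
+    # skips the kv-head based split below and gets a near-even shard of 2752 per rank.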
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
if rank == None:
rank = dist.get_rank()
if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
- name) not in last_linear:
+ name) not in last_linear and not_moe_mlp_layer:
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
return total_size * my_slices // num_kv_heads
else:
diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py
index 51377bc6092c..f9cf23373c26 100644
--- a/deepspeed/ops/fp_quantizer/__init__.py
+++ b/deepspeed/ops/fp_quantizer/__init__.py
@@ -4,9 +4,4 @@
# DeepSpeed Team
from .quantize import FP_Quantize, Quantizer
-
-try:
- import triton
- from .fp8_gemm import matmul_fp8
-except ImportError:
- pass
+from .fp8_gemm import matmul_fp8
diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm.py b/deepspeed/ops/fp_quantizer/fp8_gemm.py
index 55504e3af8c9..db4fa5ae2c92 100644
--- a/deepspeed/ops/fp_quantizer/fp8_gemm.py
+++ b/deepspeed/ops/fp_quantizer/fp8_gemm.py
@@ -11,161 +11,18 @@
###################################
import torch
-import triton
-import triton.language as tl
-@triton.jit
-def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
- stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
- BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
- quantization_group_size: tl.constexpr):
- pid = tl.program_id(axis=0)
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
- num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * GROUP_SIZE_M
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
- pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
+def matmul_fp8(inp, weight, scale, quantization_group_size, quantizer):
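+    # Dispatch to the fused triton kernel when the accelerator supports triton;
+    # otherwise fall back to dequantizing the fp8 weight and using torch.matmul.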
+ from deepspeed import get_accelerator
- offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
- offs_k = tl.arange(0, BLOCK_SIZE_K)
+ if not get_accelerator().is_triton_supported():
+ return matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer)
+ else:
+ # Import dynamically to prevent failures on systems without triton.
+ from .fp8_gemm_triton import matmul_fp8_triton
+ return matmul_fp8_triton(inp, weight, scale, quantization_group_size)
- inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
- weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
- weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
- (pid_n * BLOCK_SIZE_N) // quantization_group_size)
- weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
- scale = tl.load(scale_ptr + weight_ptrs_offset)
-
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
- inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
- # Dequantize weight (fp8 -> bf16)
- w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
- w = (w + 0x3C00).to(tl.uint16)
- w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)
-
- inp_data += BLOCK_SIZE_K * stride_ak
- weight_data += BLOCK_SIZE_K * stride_bk
- weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
- weight = tl.load(weight_data, mask=weight_mask, other=0.0)
- scale = tl.load(scale_ptr + (weight_ptrs_offset +
- (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
- mask=weight_mask,
- other=0.0)
-
- accumulator += tl.dot(inp, w)
-
- out = accumulator.to(tl.bfloat16)
-
- offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
- offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
- out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
- tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))
-
-
-@triton.jit
-def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
- stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
- BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
- quantization_group_size: tl.constexpr):
- pid = tl.program_id(axis=0)
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
- num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * GROUP_SIZE_M
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
- pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
-
- offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
- offs_k = tl.arange(0, BLOCK_SIZE_K)
-
- inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
- weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
- weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
- (pid_n * BLOCK_SIZE_N) // quantization_group_size)
-
- weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
- scale = tl.load(scale_ptr + weight_ptrs_offset)
-
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
- inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
- # Dequantize weight (fp8 -> fp16)
- w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16)
- w = (w + 0x2000).to(tl.uint16)
- w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16)
-
- inp_data += BLOCK_SIZE_K * stride_ak
- weight_data += BLOCK_SIZE_K * stride_bk
-
- weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
- scale = tl.load(scale_ptr + (weight_ptrs_offset +
- (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))
-
- accumulator += tl.dot(inp, w)
-
- out = accumulator.to(tl.float16)
-
- offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
- offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
- out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
- tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))
-
-
-def matmul_fp8(inp, weight, scale, quantization_group_size):
-
- assert inp.shape[1] == weight.shape[0], \
- f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"
-
- M, K = inp.shape
- K, N = weight.shape
-
- out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)
-
- # GEMM tuning parameters!
- # TODO: Add a more configurable tuning for selecting the best GeMM
- BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128
- BLOCK_SIZE_N = 64
- BLOCK_SIZE_K = max(64, quantization_group_size)
- GROUP_SIZE_M = 8
- num_stages = 4
- num_warps = 4
- if M >= 256:
- BLOCK_SIZE_M = 256
- BLOCK_SIZE_N = 128
- BLOCK_SIZE_K = max(128, quantization_group_size)
- num_stages = 3
- num_warps = 8
-
- grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
- kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16
- kernel[grid](inp,
- weight,
- out,
- scale,
- M,
- N,
- K,
- inp.stride(0),
- inp.stride(1),
- weight.stride(0),
- weight.stride(1),
- out.stride(0),
- out.stride(1),
- quantization_group_size=quantization_group_size,
- BLOCK_SIZE_M=BLOCK_SIZE_M,
- BLOCK_SIZE_N=BLOCK_SIZE_N,
- BLOCK_SIZE_K=BLOCK_SIZE_K,
- GROUP_SIZE_M=GROUP_SIZE_M,
- num_stages=num_stages,
- num_warps=num_warps)
- return out
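+# Portable fallback used when triton is unavailable: dequantize the fp8 weight back to the
+# activation dtype and run a plain torch.matmul (typically slower than the fused triton path).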
+def matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer):
+ return torch.matmul(inp, quantizer.dequantize(weight, scale=scale))
diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py
new file mode 100644
index 000000000000..746e217d4194
--- /dev/null
+++ b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py
@@ -0,0 +1,171 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+######## Fused MoE kernel #########
+# These kernels are implemented for
+# fusing GeMM with dequantization of
+# fp8 weight data when using bit-16
+# activation.
+###################################
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
+ stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
+ quantization_group_size: tl.constexpr):
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+ inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+ weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+ weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
+ (pid_n * BLOCK_SIZE_N) // quantization_group_size)
+
+ weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
+ scale = tl.load(scale_ptr + weight_ptrs_offset)
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+ # Dequantize weight (fp8 -> bf16)
+ w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
+ w = (w + 0x3C00).to(tl.uint16)
+ w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)
+
+ inp_data += BLOCK_SIZE_K * stride_ak
+ weight_data += BLOCK_SIZE_K * stride_bk
+ weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
+ weight = tl.load(weight_data, mask=weight_mask, other=0.0)
+ scale = tl.load(scale_ptr + (weight_ptrs_offset +
+ (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
+ mask=weight_mask,
+ other=0.0)
+
+ accumulator += tl.dot(inp, w)
+
+ out = accumulator.to(tl.bfloat16)
+
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+ tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))
+
+
+@triton.jit
+def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
+ stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
+ quantization_group_size: tl.constexpr):
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+ inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+ weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+ weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
+ (pid_n * BLOCK_SIZE_N) // quantization_group_size)
+
+ weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
+ scale = tl.load(scale_ptr + weight_ptrs_offset)
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+ # Dequantize weight (fp8 -> fp16)
+ w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16)
+ w = (w + 0x2000).to(tl.uint16)
+ w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16)
+
+ inp_data += BLOCK_SIZE_K * stride_ak
+ weight_data += BLOCK_SIZE_K * stride_bk
+
+ weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
+ scale = tl.load(scale_ptr + (weight_ptrs_offset +
+ (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))
+
+ accumulator += tl.dot(inp, w)
+
+ out = accumulator.to(tl.float16)
+
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+ tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))
+
+
+def matmul_fp8_triton(inp, weight, scale, quantization_group_size):
+
+ assert inp.shape[1] == weight.shape[0], \
+ f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"
+
+ M, K = inp.shape
+ K, N = weight.shape
+
+ out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)
+
+ # GEMM tuning parameters!
+ # TODO: Add a more configurable tuning for selecting the best GeMM
+ BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128
+ BLOCK_SIZE_N = 64
+ BLOCK_SIZE_K = max(64, quantization_group_size)
+ GROUP_SIZE_M = 8
+ num_stages = 4
+ num_warps = 4
+ if M >= 256:
+ BLOCK_SIZE_M = 256
+ BLOCK_SIZE_N = 128
+ BLOCK_SIZE_K = max(128, quantization_group_size)
+ num_stages = 3
+ num_warps = 8
+
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
+ kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16
+ kernel[grid](inp,
+ weight,
+ out,
+ scale,
+ M,
+ N,
+ K,
+ inp.stride(0),
+ inp.stride(1),
+ weight.stride(0),
+ weight.stride(1),
+ out.stride(0),
+ out.stride(1),
+ quantization_group_size=quantization_group_size,
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
+ GROUP_SIZE_M=GROUP_SIZE_M,
+ num_stages=num_stages,
+ num_warps=num_warps)
+ return out
diff --git a/deepspeed/ops/transformer/inference/triton/matmul_ext.py b/deepspeed/ops/transformer/inference/triton/matmul_ext.py
index 412c8740a216..9be4b0098c37 100644
--- a/deepspeed/ops/transformer/inference/triton/matmul_ext.py
+++ b/deepspeed/ops/transformer/inference/triton/matmul_ext.py
@@ -19,6 +19,9 @@
# -----------------------------------------------------------------------------
# util class/functions for triton
def is_nfs_path(path):
+ if os.name == 'nt':
+ return False
+
# Normalize the path to get the absolute path
path = os.path.abspath(path)
@@ -99,7 +102,7 @@ def put(self, table):
with FileLock(self.lock_path):
with open(self.file_path + ".tmp", 'wb') as handle:
pickle.dump(table, handle)
- os.rename(self.file_path + ".tmp", self.file_path)
+ os.replace(self.file_path + ".tmp", self.file_path)
def load(self):
if os.path.exists(self.file_path):
diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py
index fa9220f4fcd0..be778b83f8bb 100644
--- a/deepspeed/runtime/compiler.py
+++ b/deepspeed/runtime/compiler.py
@@ -4,6 +4,7 @@
# DeepSpeed Team
import torch
+from deepspeed.utils.torch import required_torch_version
try:
from torch.compiler import is_compiling as torch_is_compiling
@@ -16,7 +17,7 @@
def is_compile_supported():
- return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")
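+    # torch.nn.Module.compile was introduced in torch 2.1, so a version check covers
+    # what the previous hasattr-based probing verified.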
+ return required_torch_version(min_version=2.1)
def disable(func):
diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py
index 8eb95e49c29d..88c5494c8147 100644
--- a/deepspeed/runtime/domino/transformer.py
+++ b/deepspeed/runtime/domino/transformer.py
@@ -6,8 +6,7 @@
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
-import deepspeed
-from deepspeed import comm as dist
+import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
@@ -97,7 +96,7 @@ def backward(ctx, grad_output):
return grad_output
# Async All-reduce.
- handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
+ handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
ctx.handle_dic[ctx.h_id] = handle
return None, grad_output, None, None
@@ -249,6 +248,10 @@ def __init__(self,
output_bias=None):
super(DominoTransformerLayer, self).__init__()
+ if not dist.is_initialized():
+ dist.init_distributed()
+ assert dist.is_initialized(), "deepspeed.comm is not initialized!"
+
self.llama_model = config.llama_model
self.layer_number = layer_number
self.layer_type = layer_type
@@ -358,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
layernorm_output0,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
- handle0 = deepspeed.comm.all_reduce(attention_output0,
- group=self.mpu.get_tensor_model_parallel_group(),
- async_op=True)
+ handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
attention_output1, attention_bias1 = \
self.self_attention(
layernorm_output1,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
- handle1 = deepspeed.comm.all_reduce(attention_output1,
- group=self.mpu.get_tensor_model_parallel_group(),
- async_op=True)
+ handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle0.wait()
# Residual0 connection.
@@ -413,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
output0 = output0 + bias_c
output0 = self.mlp_activation_func(output0)
output0 = torch.matmul(output0, self.weight_r.t())
- handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+ handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle1.wait()
@@ -425,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
if bias_c is not None:
output1 = output1 + bias_c
output1 = torch.matmul(output1, self.weight_r.t())
- deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
+ dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
handle2.wait()
diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 31fec30be788..49fa2807c355 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -116,7 +116,9 @@ def forward(self, inputs):
partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
- checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering.
+ checkpointable_layers (list[str], optional): List of layer class names that are eligible for checkpointing. For GPT models,
+ ParallelTransformerLayerPipe is always checkpointed regardless of this list. If None, all layers with parameters are
+ considered checkpointable. Defaults to None.
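+            For example, ``checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"]`` limits
+            checkpointing to those layer classes (plus the always-checkpointed GPT transformer layers).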
dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact.
"""
@@ -650,9 +652,17 @@ def _is_checkpointable(self, funcs):
# because only non_reentrant_checkpoint can accept inputs with requires_grad=False
# otherwise, the backward of the embedding layer won't receive gradients.
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
- return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
+ # For GPT models, checkpoint both transformer layers and any additional
+ # layers specified in checkpointable_layers (if provided)
+ return all('ParallelTransformerLayerPipe' in f.__class__.__name__ or (
+ self.checkpointable_layers is not None and f.__class__.__name__ in self.checkpointable_layers)
+ for f in funcs)
+
if self.checkpointable_layers is not None:
+ # For non-GPT models, only checkpoint layers specified in checkpointable_layers
return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)
+
+ # Default behavior: checkpoint any layer that has parameters
params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
return any(len(list(p)) > 0 for p in params)
@@ -662,3 +672,11 @@ def get_additional_losses(self):
Return a dictionary of {"loss name": loss_value} or None if no additional losses.
"""
return None
+
+ def compile(self, *args, **kwargs):
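+        # Compile each pipeline layer in place: nn.Module layers use their own .compile(),
+        # while plain callables are wrapped with torch.compile and swapped back into forward_funcs.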
+ for idx, layer in enumerate(self.forward_funcs):
+ if isinstance(layer, nn.Module):
+ layer.compile(*args, **kwargs)
+ else:
+ new_layer = torch.compile(layer, *args, **kwargs)
+ self.forward_funcs[idx] = new_layer
diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py
index c9ae58a121de..628bf86a61da 100755
--- a/deepspeed/runtime/zero/mics.py
+++ b/deepspeed/runtime/zero/mics.py
@@ -38,7 +38,7 @@ class MiCS_AllGatherCoalescedHandle(AllGatherCoalescedHandle):
def __init__(self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int) -> None:
super().__init__(allgather_handle, params, partitions, world_size)
- def wait(self) -> None:
+ def wait(self, **kwargs) -> None:
"""
"""
# let the current stream to op
diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
index 0be88a1e1ba6..d5b7bac55146 100644
--- a/deepspeed/runtime/zero/parameter_offload.py
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -145,6 +145,16 @@ def __init__(
module.ds_inflight_param_registry = InflightParamRegistry()
self.__inflight_param_registry = module.ds_inflight_param_registry
+ self.fast_sharding_for_leaf_module = False
+
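+        # When a module-granularity threshold is set, pre-select ZeRO-3 leaf modules so the
+        # parameter coordinator can use the faster whole-module fetch/release path for them.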
+ if zero_module_granularity_threshold > 0:
+ self.min_granularity_value = sys.maxsize
+ self.min_granularity_layer = None
+ self.granularity_info = set()
+ self.z3_leaf_layers = []
+ self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)
+ self.fast_sharding_for_leaf_module = True
+
self.param_coordinator = PartitionedParameterCoordinator(
prefetch_bucket_sz=self._prefetch_bucket_sz,
max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
@@ -155,14 +165,7 @@ def __init__(
timers=self.timers,
zero_quantized_weights=self.zero_quantized_weights,
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
- )
-
- if zero_module_granularity_threshold > 0:
- self.min_granularity_value = sys.maxsize
- self.min_granularity_layer = None
- self.granularity_info = set()
- self.z3_leaf_layers = []
- self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)
+ fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module)
self.forward_hooks = []
self.backward_hooks = []
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index cb0cd7c8017d..e8cb797b8a5b 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -55,7 +55,7 @@ def __init__(self, param: Parameter) -> None:
non_blocking=True).view(param.ds_shape)
self.__param = param
- def wait(self) -> None:
+ def wait(self, **kwargs) -> None:
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().synchronize()
self.__param.ds_status = ZeroParamStatus.AVAILABLE
@@ -78,7 +78,7 @@ def __init__(self, params: List[Parameter]) -> None:
non_blocking=True).view(param.ds_shape)
@instrument_w_nvtx
- def wait(self) -> None:
+ def wait(self, **kwargs) -> None:
if self.__complete:
return
@@ -639,7 +639,7 @@ def __init__(self, handle, param: Parameter, quantization=None) -> None:
self.__param = param
self.__quantization = quantization
- def wait(self) -> None:
+ def wait(self, handle_dependency=True) -> None:
instrument_w_nvtx(self.__handle.wait)()
if self.__quantization:
instrument_w_nvtx(self.__quantization.quant_handle.wait)()
@@ -650,6 +650,8 @@ def wait(self) -> None:
class AllGatherCoalescedHandle:
+ data_buffer = []
+
def __init__(
self,
allgather_handle,
@@ -672,7 +674,7 @@ def __init__(
raise RuntimeError(f"expected param {param.ds_summary()} to not be available")
@instrument_w_nvtx
- def wait(self) -> None:
+ def wait(self, handle_dependency=True) -> None:
if self.complete:
return
@@ -704,14 +706,20 @@ def wait(self) -> None:
partitions.append(part_to_copy)
param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
param.ds_status = ZeroParamStatus.AVAILABLE
-
- for part_to_copy in partitions:
- if not get_accelerator().is_synchronized_device():
+ if not get_accelerator().is_synchronized_device() and handle_dependency:
+ for part_to_copy in partitions:
part_to_copy.record_stream(get_accelerator().current_stream())
param_offset += ds_tensor_numel
self.complete = True
+ if not get_accelerator().is_synchronized_device() and not handle_dependency:
+            # The device needs dependency handling, but the caller opted to do it outside this
+            # function; keep the gathered partitions alive in data_buffer until free_buffer() is called.
+ AllGatherCoalescedHandle.data_buffer.append(partitions)
+
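+    # Called once the consumer has synchronized with the allgather stream (see
+    # fetch_sub_module); drops the partitions retained when wait() skipped dependency handling.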
+ @staticmethod
+ def free_buffer():
+ AllGatherCoalescedHandle.data_buffer = []
class MultipleAllGatherHandles:
@@ -719,9 +727,9 @@ class MultipleAllGatherHandles:
def __init__(self, handles: List[AllGatherCoalescedHandle]):
self.handles = handles
- def wait(self) -> None:
+ def wait(self, handle_dependency=True) -> None:
for handle in self.handles:
- handle.wait()
+ handle.wait(handle_dependency)
class AllReduceCoalescedHandle:
@@ -1377,13 +1385,13 @@ def all_gather_coalesced(params: Iterable[Parameter],
quantization=quant_info,
)
- def partition(param_list=None, hierarchy=0, has_been_updated=False):
+ def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True):
cls = param
print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
force=False)
if param_list is None:
param_list = [cls]
- self._partition(param_list, has_been_updated=has_been_updated)
+            self._partition(param_list, has_been_updated=has_been_updated, free_data=free_data)
def reduce_gradients_at_owner(param_list=None, hierarchy=0):
cls = param
@@ -1527,12 +1535,12 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
return handles
- def _partition(self, param_list, force=False, has_been_updated=False):
+ def _partition(self, param_list, force=False, has_been_updated=False, free_data=True):
for param in param_list:
print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
if self.zero_param_process_group is not None:
self._partition_param_sec(param)
- self._partition_param(param, has_been_updated=has_been_updated)
+            self._partition_param(param, has_been_updated=has_been_updated, free_data=free_data)
param.ds_status = ZeroParamStatus.NOT_AVAILABLE
# if param.ds_tensor is not None:
@@ -1540,7 +1548,7 @@ def _partition(self, param_list, force=False, has_been_updated=False):
# "After the parameters are initially partitioned, make sure we are not recreating the partition."
#print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False)
@instrument_w_nvtx
- def _partition_param(self, param, buffer=None, has_been_updated=False):
+ def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True):
assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
global reuse_buffers
print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
@@ -1565,7 +1573,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
# param.data does not store anything meaningful in partitioned state
- free_param(param)
+ if free_data:
+ free_param(param)
see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)
if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index 596d0e9c20f9..08cb6c0de54f 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -76,18 +76,17 @@ class __ParamInTrace:
param: Parameter
step_id_last_used_at: int
- def __init__(
- self,
- prefetch_bucket_sz: int,
- max_reuse_distance_in_numel: int,
- max_available_parameters_in_numel: int,
- allgather_stream: get_accelerator().Stream,
- inflight_param_registry: InflightParamRegistry,
- prefetch_nvme: bool = False,
- timers=None,
- zero_quantized_weights=False,
- zero_quantized_nontrainable_weights=False,
- ) -> None:
+ def __init__(self,
+ prefetch_bucket_sz: int,
+ max_reuse_distance_in_numel: int,
+ max_available_parameters_in_numel: int,
+ allgather_stream: get_accelerator().Stream,
+ inflight_param_registry: InflightParamRegistry,
+ prefetch_nvme: bool = False,
+ timers=None,
+ zero_quantized_weights=False,
+ zero_quantized_nontrainable_weights=False,
+ fast_sharding_for_leaf_module=False) -> None:
# mapping of param -> handle for each param that is currently in flight
self.__inflight_param_registry = inflight_param_registry
# keeps track of the number of submodules invoked so far.
@@ -130,6 +129,10 @@ def __init__(
self.__max_ongoing_fetch_events: int = 2
self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None)
+        # Whether to enable fast fetch for ZeRO-3 leaf modules. This improves fetch speed,
+        # but leaf-module parameters are not individually partitioned to relieve memory pressure.
+ self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module
+
"""Tracing and Tracking
TODO. consider performing trace before initializing PartitionedParameterCoordinator
and passing trace results into constructor. This way all the code in here can
@@ -308,6 +311,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
wait_numel = 0
wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT
self.__profiler.start_event(wait_event_name)
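+        # Fast fetch applies only to ZeRO-3 leaf modules when fast sharding is enabled; it skips
+        # per-param dependency bookkeeping and frees the retained buffers in one shot below.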
+ fast_fetch = self.fast_sharding_for_leaf_module and z3_leaf_module(current_submodule)
# wait for parameters in the immediately needed submodule to become available
for param in params_to_fetch:
param.ds_active_sub_modules.add(current_submodule.id)
@@ -321,9 +325,9 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
self.__ongoing_fetch_events.popleft().synchronize()
- self.__inflight_param_registry.pop(param).wait()
+ self.__inflight_param_registry.pop(param).wait(handle_dependency=not fast_fetch)
- if not get_accelerator().handles_memory_backpressure():
+ if not get_accelerator().handles_memory_backpressure() and not fast_fetch:
event = get_accelerator().Event()
event.record()
self.__ongoing_fetch_events.append(event)
@@ -331,6 +335,8 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().wait_stream(self.__allgather_stream)
+ if fast_fetch:
+ AllGatherCoalescedHandle.free_buffer()
self.__profiler.stop_event(wait_event_name, wait_numel)
# kick off parameter prefetches for upcoming modules
@@ -412,10 +418,20 @@ def release_sub_module(self, submodule: Module) -> None:
be released."""
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
+
+ free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module
+ if not free_data:
+            # Allocate the placeholder as early as possible; assigning it to param.data below
+            # releases the gathered tensor once outstanding computation has finished.
+ empty_buffer = torch.empty(1, device=get_accelerator().current_device())
+
for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
param.ds_active_sub_modules.discard(submodule.id)
if param.ds_id in params_to_release and not param.is_external_param:
- self.__release_param(param)
+ self.__release_param(param, free_data)
+ if not free_data:
+ if param.ds_id in params_to_release and not param.is_external_param:
+                    # Replacing param.data with the empty buffer ensures all computations on it
+                    # have completed before the gathered memory is released.
+ param.data = empty_buffer
@instrument_w_nvtx
@torch.no_grad()
@@ -490,11 +506,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:
@compiler.disable
@instrument_w_nvtx
- def __release_param(self, param: Parameter) -> None:
+ def __release_param(self, param: Parameter, free_data: bool = True) -> None:
if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-release: {param.ds_summary()}")
- param.partition()
+ param.partition(free_data=free_data)
self.__n_available_params -= param.ds_numel
@instrument_w_nvtx
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 459cffce52c8..28f91cb9b3ab 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -16,6 +16,7 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.utils import logger
+from deepspeed.utils.torch import register_grad_hook
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
@@ -1159,7 +1160,6 @@ def overlapping_partition_gradients_reduce_epilogue(self):
def create_reduce_and_remove_grad_hooks(self):
print_rank_0(f'[Begin] Create gradient reduction hooks')
- self.grad_accs = []
self.leaf_parameters = defaultdict(list)
for i, param_group in enumerate(self.fp16_groups):
for param in param_group:
@@ -1172,15 +1172,12 @@ def create_reduce_and_remove_grad_hooks(self):
#print(f"After all gather {param.device}, {param.shape}")
def wrapper(param):
- param_tmp = param.expand_as(param)
- grad_acc = param_tmp.grad_fn.next_functions[0][0]
@instrument_w_nvtx
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(param)
- self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
- self.grad_accs.append(grad_acc)
+ self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads))
#print(f"param grad fn {param.expand_as(param).grad_fn}")
if z3_leaf_parameter(param):
diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py
index 4fab768ce63c..4fa2cc988a19 100644
--- a/deepspeed/sequence/fpdt_layer.py
+++ b/deepspeed/sequence/fpdt_layer.py
@@ -47,7 +47,7 @@ def _update_out_and_lse(
block_out = block_out.to(torch.float32)
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
- new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
+ new_lse = lse + torch.log1p(torch.exp(block_lse - lse))
out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
diff --git a/deepspeed/utils/torch.py b/deepspeed/utils/torch.py
index eb22d3561035..1d32775fe64a 100644
--- a/deepspeed/utils/torch.py
+++ b/deepspeed/utils/torch.py
@@ -20,3 +20,12 @@ def required_torch_version(min_version=None, max_version=None):
return False
return True
+
+
+def register_grad_hook(param, hook):
+ if required_torch_version(min_version=2.1):
+ return param.register_post_accumulate_grad_hook(hook)
+ else:
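+        # Pre-2.1 fallback: expand_as creates a view whose grad_fn chains to this param's
+        # AccumulateGrad node, so the hook fires after the gradient has been accumulated.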
+ param_tmp = param.expand_as(param)
+ grad_acc = param_tmp.grad_fn.next_functions[0][0]
+ return grad_acc.register_hook(hook)
diff --git a/op_builder/builder.py b/op_builder/builder.py
index 461281d4a569..ab26054bda7d 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -415,10 +415,11 @@ def cpu_arch(self):
return '-mcpu=native'
return '-march=native'
- def is_cuda_enable(self):
+ def get_cuda_compile_flag(self):
try:
- assert_no_cuda_mismatch(self.name)
- return '-D__ENABLE_CUDA__'
+ if not self.is_rocm_pytorch():
+ assert_no_cuda_mismatch(self.name)
+ return "-D__ENABLE_CUDA__"
except MissingCUDAException:
print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
"only cpu ops can be compiled!")
@@ -839,7 +840,7 @@ def cxx_args(self):
CPU_ARCH = self.cpu_arch()
SIMD_WIDTH = self.simd_width()
- CUDA_ENABLE = self.is_cuda_enable()
+ CUDA_ENABLE = self.get_cuda_compile_flag()
args += [
CPU_ARCH,
'-fopenmp',
diff --git a/setup.py b/setup.py
index c0452f867b31..cc5eb4a3500c 100755
--- a/setup.py
+++ b/setup.py
@@ -321,9 +321,9 @@ def op_enabled(op_name):
include_package_data=True,
scripts=scripts,
classifiers=[
- 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9',
- 'Programming Language :: Python :: 3.10'
+ 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11',
+ 'Programming Language :: Python :: 3.12'
],
license='Apache Software License 2.0',
ext_modules=ext_modules,
diff --git a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py
index d66f7c8cb4cc..a4cf579f5943 100644
--- a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py
+++ b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py
@@ -14,6 +14,8 @@
from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8
+from deepspeed import get_accelerator
+
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize("q_bits", [8], ids=[
@@ -21,23 +23,19 @@
])
@pytest.mark.parametrize("M", [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024, 2048])
def test_fp_quant(dtype, q_bits, M):
+ device_name = get_accelerator().device_name()
quantization_group_size = 128
fpq = FP_Quantize(group_size=quantization_group_size)
N = 8192
H = 4096
- x = torch.randn(M, H, dtype=dtype, device='cuda')
- weight_bf16 = torch.randn(H, N, dtype=dtype, device='cuda')
+ x = torch.randn(M, H, dtype=dtype, device=device_name)
+ weight_bf16 = torch.randn(H, N, dtype=dtype, device=device_name)
- weight, _ = fpq.quantize(weight_bf16.data, q_bits=8, return_meta_tensor=True)
+ weight, _ = fpq.quantize(weight_bf16.data, q_bits=q_bits, return_meta_tensor=True)
scale = fpq.get_scales()
- out = matmul_fp8(
- x,
- weight,
- scale,
- quantization_group_size,
- )
+ out = matmul_fp8(x, weight, scale, quantization_group_size, fpq)
out_q = torch.matmul(x, fpq.dequantize(weight, scale=fpq.scale))
diff --git a/tests/unit/ops/transformer/inference/inference_test_utils.py b/tests/unit/ops/transformer/inference/inference_test_utils.py
index 9cfcae809f09..d63c51267e51 100644
--- a/tests/unit/ops/transformer/inference/inference_test_utils.py
+++ b/tests/unit/ops/transformer/inference/inference_test_utils.py
@@ -3,6 +3,8 @@
# DeepSpeed Team
+from typing import Tuple
+
import torch
from deepspeed.accelerator import get_accelerator
@@ -23,38 +25,22 @@ def get_tolerances():
DTYPES = None
-def get_dtypes():
+def get_dtypes(include_float=True):
global DTYPES
if DTYPES is None:
- DTYPES = get_accelerator().supported_dtypes()
+ DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16]
+ try:
+ if get_accelerator().is_bf16_supported():
+ DTYPES.append(torch.bfloat16)
+ except (AssertionError, AttributeError):
+ pass
return DTYPES
-def allclose(x, y):
+def allclose(x, y, tolerances: Tuple[float, float] = None):
assert x.dtype == y.dtype
- rtol, atol = get_tolerances()[x.dtype]
+ if tolerances is None:
+ rtol, atol = get_tolerances()[x.dtype]
+ else:
+ rtol, atol = tolerances
return torch.allclose(x, y, rtol=rtol, atol=atol)
-
-
-def assert_almost_equal(x, y, decimal=2, err_msg=''):
- import numpy.testing as npt
- if isinstance(x, torch.Tensor):
- if x.dtype == torch.bfloat16:
- x = x.float()
- x = x.cpu().detach().numpy()
- if isinstance(y, torch.Tensor):
- if y.dtype == torch.bfloat16:
- y = y.float()
- y = y.cpu().detach().numpy()
- npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
-
-
-def max_diff(a, b):
- a = a.to(torch.float32).flatten()
- b = b.to(torch.float32).flatten()
- diff = torch.abs(a - b)
- max_diff_indices = torch.argsort(diff)[-1]
- print("Max difference indices:", max_diff_indices)
- print("Max difference values:", diff[max_diff_indices])
- print(f"{a[max_diff_indices]} vs {b[max_diff_indices]}")
- return max_diff_indices
diff --git a/tests/unit/ops/transformer/inference/test_attention.py b/tests/unit/ops/transformer/inference/test_attention.py
index ecf681542ff6..cae201d747a3 100644
--- a/tests/unit/ops/transformer/inference/test_attention.py
+++ b/tests/unit/ops/transformer/inference/test_attention.py
@@ -7,7 +7,7 @@
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
-from .inference_test_utils import assert_almost_equal
+from .inference_test_utils import allclose
# reference timplementation
@@ -88,4 +88,4 @@ def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float
use_triton_flash=False,
use_ds_attention=False)
tri_out = tri_out.reshape((BATCH, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3)
- assert_almost_equal(ref_out, tri_out)
+    assert allclose(ref_out, tri_out)
diff --git a/tests/unit/ops/transformer/inference/test_bias_add.py b/tests/unit/ops/transformer/inference/test_bias_add.py
index f25bbc1be692..eb283924f73c 100644
--- a/tests/unit/ops/transformer/inference/test_bias_add.py
+++ b/tests/unit/ops/transformer/inference/test_bias_add.py
@@ -15,8 +15,6 @@
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
-torch_minor_version = None
-
def run_bias_add_reference(activations, bias):
return activations + bias
diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py
index e3a3bad63961..f0a09245e890 100644
--- a/tests/unit/ops/transformer/inference/test_bias_gelu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py
@@ -10,8 +10,8 @@
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.ops.transformer import DeepSpeedInferenceConfig
from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp
+from deepspeed.utils.torch import required_torch_version
from .inference_test_utils import allclose, get_dtypes
-from packaging import version as pkg_version
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
@@ -34,7 +34,7 @@ def run_bias_gelu_ds(activations, bias):
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_bias_gelu(batch, sequence, channels, dtype):
- if pkg_version.parse(torch.__version__) < pkg_version.parse("1.12"):
+ if not required_torch_version(min_version=1.12):
pytest.skip("gelu implementation matches only after torch 1.12")
activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=get_accelerator().device_name())
diff --git a/tests/unit/ops/transformer/inference/test_layer_norm.py b/tests/unit/ops/transformer/inference/test_layer_norm.py
index 7711daf0d887..4a84add16046 100644
--- a/tests/unit/ops/transformer/inference/test_layer_norm.py
+++ b/tests/unit/ops/transformer/inference/test_layer_norm.py
@@ -9,7 +9,7 @@
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp
-from .inference_test_utils import allclose, get_dtypes, assert_almost_equal
+from .inference_test_utils import allclose, get_dtypes
try:
import triton # noqa: F401 # type: ignore
from deepspeed.ops.transformer.inference.triton import (
@@ -188,4 +188,4 @@ def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='
y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias,
eps).to(dtype)
# compare
- assert_almost_equal(y_tri, y_ref)
+    assert allclose(y_tri, y_ref)
diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py
index 2ab195ee0115..6f5173bbc827 100644
--- a/tests/unit/ops/transformer/inference/test_matmul.py
+++ b/tests/unit/ops/transformer/inference/test_matmul.py
@@ -11,8 +11,6 @@
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
-inference_module = None
-
def allclose(x, y):
assert x.dtype == y.dtype
diff --git a/tests/unit/pipe/test_pipe_module.py b/tests/unit/pipe/test_pipe_module.py
index 05c6a82ef55a..2a8a4b9b7d82 100644
--- a/tests/unit/pipe/test_pipe_module.py
+++ b/tests/unit/pipe/test_pipe_module.py
@@ -60,9 +60,12 @@ def batch_input():
class TestPipeModuleSequential(DistributedTest):
world_size = 2
+    # Required for torch.compile: running torch.compile in a daemonic process raises an error.
+ non_daemonic_procs = True
@pytest.mark.parametrize("activation_checkpoints", [False, True])
- def test(self, sequential_model, simple_config, batch_input, activation_checkpoints):
+ @pytest.mark.parametrize("use_compile", [False, True])
+ def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile):
base_model = copy.deepcopy(sequential_model)
base_input = batch_input.clone().detach()
base_output = base_model(base_input)
@@ -71,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi
pipe_model = copy.deepcopy(sequential_model)
pipe_model = PipelineModule(layers=pipe_model, num_stages=2)
-
+        if use_compile:
+            pipe_model.compile()
# Ensure all parameters are accounted for.
my_params = sum(p.numel() for p in pipe_model.parameters())
total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name())
diff --git a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py
index 22a61003b31e..dd3bcd7fb6bd 100644
--- a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py
+++ b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py
@@ -8,6 +8,7 @@
import pytest
import torch
import deepspeed
+from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.accelerator import get_accelerator
from copy import deepcopy
from unit.common import DistributedTest
@@ -259,3 +260,52 @@ def test_ckpt_non_tensor_output_ordering(self, non_tensor_output):
else:
ordering += [torch.is_tensor(non_tensor_output)]
_test_activation_checkpoint_ordering(module, ordering, inputs)
+
+
+class TestCheckpointableLayersConfig(DistributedTest):
+ world_size = 1
+
+ def test_gpt2_checkpointable_layers(self):
+ if get_accelerator().device_name() == "cpu":
+ pytest.skip("CPU accelerator does not support this test yet")
+
+ # Create a simple topology for testing
+ from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
+ topo = PipeModelDataParallelTopology(num_pp=1, num_mp=1, num_dp=1)
+
+ # Create test classes that we want to checkpoint
+ class TestTransformerLayer(torch.nn.Module):
+
+ def forward(self, x):
+ return x
+
+ class ParallelTransformerLayerPipe(TestTransformerLayer):
+ pass
+
+ class GMLPBlock(TestTransformerLayer):
+ pass
+
+ # Create a mock GPT2 model with different layer types
+ class TestGPT2ModelPipe(PipelineModule):
+
+ def __init__(self):
+ self.layers_spec = [
+ LayerSpec(ParallelTransformerLayerPipe),
+ LayerSpec(GMLPBlock),
+ LayerSpec(torch.nn.Linear, 10, 10), # Should not be checkpointed
+ ]
+
+ super().__init__(layers=self.layers_spec,
+ topology=topo,
+ checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"])
+
+ model = TestGPT2ModelPipe()
+ model.to(get_accelerator().device_name())
+
+ # Build layers manually for testing
+ layers = [spec.build() for spec in model.layers_spec]
+
+ # Test that _is_checkpointable returns correct values
+        assert model._is_checkpointable([layers[0]])  # ParallelTransformerLayerPipe
+        assert model._is_checkpointable([layers[1]])  # GMLPBlock
+        assert not model._is_checkpointable([layers[2]])  # Linear layer is not checkpointed