
Commit 19e1a5c

[shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
1 parent 9a3321e commit 19e1a5c


45 files changed: +2538 −1165 lines

.github/workflows/build_on_pr.yml (+2 −2)

@@ -117,7 +117,7 @@ jobs:
           cd TensorNVMe
           conda install cmake
           pip install -r requirements.txt
-          pip install -v .
+          DISABLE_URING=1 pip install -v .

       - name: Store TensorNVMe Cache
         run: |
@@ -201,4 +201,4 @@ jobs:
         uses: actions/upload-artifact@v3
         with:
           name: report
-          path: report/
+          path: report/

(The removed and added lines of the last hunk are visibly identical; most likely only the trailing newline at end of file changed.)
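The same one-line change is repeated in the four workflow files that follow: the TensorNVMe build is now invoked as DISABLE_URING=1 pip install -v . Judging by the flag's name, this opts the build out of the io_uring backend, presumably to keep CI green on runners without io_uring support; the TensorNVMe build script that reads the variable is not part of this diff.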

.github/workflows/build_on_schedule.yml (+1 −1)

@@ -44,7 +44,7 @@ jobs:
           cd TensorNVMe
           conda install cmake
           pip install -r requirements.txt
-          pip install -v .
+          DISABLE_URING=1 pip install -v .

       - uses: actions/checkout@v2
         if: steps.check-avai.outputs.avai == 'true'

.github/workflows/compatiblity_test_on_dispatch.yml (+1 −1)

@@ -66,7 +66,7 @@ jobs:
           cd TensorNVMe
           apt update && apt install -y cmake
           pip install -r requirements.txt
-          pip install -v .
+          DISABLE_URING=1 pip install -v .
       - uses: actions/checkout@v2
         with:
           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}

.github/workflows/compatiblity_test_on_pr.yml (+1 −1)

@@ -60,7 +60,7 @@ jobs:
           cd TensorNVMe
           apt update && apt install -y cmake
           pip install -r requirements.txt
-          pip install -v .
+          DISABLE_URING=1 pip install -v .
       - uses: actions/checkout@v2
         with:
           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}

.github/workflows/compatiblity_test_on_schedule.yml (+1 −1)

@@ -56,7 +56,7 @@ jobs:
           cd TensorNVMe
           apt update && apt install -y cmake
           pip install -r requirements.txt
-          pip install -v .
+          DISABLE_URING=1 pip install -v .
       - uses: actions/checkout@v2
         with:
           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}

colossalai/kernel/kernel_loader.py (+20 −4)

@@ -6,7 +6,7 @@
     CpuAdamX86Extension,
     FlashAttentionDaoCudaExtension,
     FlashAttentionNpuExtension,
-    FlashAttentionXformersCudaExtension,
+    FlashAttentionSdpaCudaExtension,
     FusedOptimizerCudaExtension,
     LayerNormCudaExtension,
     MoeCudaExtension,
@@ -65,9 +65,9 @@ def load(self, ext_name: str = None):
         else:
             usable_exts = []
             for ext in exts:
-                if ext.is_hardware_available():
+                if ext.is_available():
                     # make sure the machine is compatible during kernel loading
-                    ext.assert_hardware_compatible()
+                    ext.assert_compatible()
                     usable_exts.append(ext)

         assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
@@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):


 class FlashAttentionLoader(KernelLoader):
-    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
+    REGISTRY = [
+        FlashAttentionNpuExtension,
+        FlashAttentionDaoCudaExtension,
+        FlashAttentionSdpaCudaExtension,
+    ]
+
+
+class FlashAttentionWithPaddingMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
+
+
+class FlashAttentionWithCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
+
+
+class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionSdpaCudaExtension]
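The new loaders split kernel selection by mask capability: the Dao flash-attn backend stays in the padding-mask registry, the new SDPA backend covers custom (including float) masks, and the NPU extension appears in both. A minimal usage sketch, assuming load() returns the selected extension's attention kernel (the diff shows load()'s filtering logic but not its return value):

from colossalai.kernel.kernel_loader import (
    FlashAttentionLoader,
    FlashAttentionWithCustomMaskLoader,
)

# load() walks REGISTRY, keeps only extensions whose is_available() and
# assert_compatible() checks pass, and returns a usable kernel.
flash_attn = FlashAttentionLoader().load()

# When the model supplies an arbitrary (custom) attention mask, dispatch
# goes through the loader whose REGISTRY lists only backends that accept
# such masks.
flash_attn_custom = FlashAttentionWithCustomMaskLoader().load()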

colossalai/nn/layer/colo_attention.py (−209)

This file was deleted. Its role is taken over by the new colossalai/shardformer/layer/attn.py module, whose AttnMaskType and ColoAttention are re-exported in the diff below.

colossalai/shardformer/layer/__init__.py (+3)

@@ -1,3 +1,4 @@
+from .attn import AttnMaskType, ColoAttention
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
@@ -23,4 +24,6 @@
     "FusedRMSNorm",
     "FusedLinear1D_Col",
     "ParallelModule",
+    "AttnMaskType",
+    "ColoAttention",
 ]
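The re-export gives model policies a stable import path for the new attention front end. A minimal sketch, assuming AttnMaskType is an enum of supported mask layouts (its definition lives in the new attn.py, which this excerpt does not include):

from colossalai.shardformer.layer import AttnMaskType, ColoAttention

# The mask type tags which mask layout a model passes in, letting
# ColoAttention dispatch to the matching FlashAttention*Loader above.
for mask_type in AttnMaskType:
    print(mask_type)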
