[shardformer] merge shardformer to main #4152

Merged: 49 commits, Jul 4, 2023

Commits
604a213
[shardformer] init shardformer code structure (#3731)
FoolPlayer May 22, 2023
ffacf0f
[shardformer]: Feature/shardformer, add some docstring and readme (#3…
FoolPlayer May 24, 2023
69d3daa
[shardformer] updated readme (#3827)
FrankLeeeee May 24, 2023
0470f1b
[shardformer] refactored the user api (#3828)
FrankLeeeee May 24, 2023
051e970
[shardformer] update readme with modules implement doc (#3834)
FoolPlayer May 24, 2023
3e840f7
[shardformer] add Dropout layer support different dropout pattern (#3…
FoolPlayer Jun 1, 2023
bf9c2fd
update README (#3909)
FoolPlayer Jun 6, 2023
551fec3
[shardformer] add gpt2 policy and modify shard and slicer to support …
FoolPlayer Jun 7, 2023
e5bc7e3
[shardformer] Align bert value (#3907)
FoolPlayer Jun 9, 2023
661dc3b
[shardformer] Unit test (#3928)
FoolPlayer Jun 12, 2023
702513a
[shardformer] Add dropout layer in shard model and refactor policy ap…
FoolPlayer Jun 12, 2023
17d1607
[shardformer] support llama model using shardformer (#3969)
wukong1992 Jun 13, 2023
e849d1b
[shardformer] shardformer support t5 model (#3994)
wukong1992 Jun 15, 2023
735e44b
[Shardformer] Downstream bert (#3979)
FoolPlayer Jun 15, 2023
73cacb7
[shardformer] fix an error in readme (#3988)
FoolPlayer Jun 15, 2023
45a3110
[device] support init device mesh from process group (#3990)
FrankLeeeee Jun 15, 2023
18396e7
[shardformer] Refactor shardformer api (#4001)
FoolPlayer Jun 15, 2023
579b617
[shardformer] integrated linear 1D with dtensor (#3996)
FrankLeeeee Jun 15, 2023
bdc405e
integrate with dist layer (#4011)
FoolPlayer Jun 16, 2023
2c366e3
[shardformer] refactored embedding and dropout to parallel module (#4…
FrankLeeeee Jun 16, 2023
eaa46d7
[shardformer] removed inplace tensor sharding (#4018)
FrankLeeeee Jun 16, 2023
60eb380
add vocabembedding layer
FoolPlayer Jun 16, 2023
90e1a0a
support bert with new api
FoolPlayer Jun 16, 2023
38ceded
[shardformer] updated doc (#4016)
FrankLeeeee Jun 16, 2023
c982769
[shardformer] fix bert and gpt downstream with new api (#4024)
FoolPlayer Jun 19, 2023
b2c5dd0
[shardformer] adapted llama to the new API (#4036)
FrankLeeeee Jun 19, 2023
8219d96
[shardformer] supported T5 and its variants (#4045)
FrankLeeeee Jun 19, 2023
0113097
[shardformer] add gpt2 test and layer class refactor (#4041)
FoolPlayer Jun 20, 2023
ac3aef3
[shardformer] adapted T5 and LLaMa test to use kit (#4049)
FrankLeeeee Jun 21, 2023
e5d4a87
[shardformer] refactored the shardformer layer structure (#4053)
FrankLeeeee Jun 21, 2023
d5d9178
support kit use for bert/gpt test (#4055)
FoolPlayer Jun 22, 2023
9436f73
[shardformer] support module saving and loading (#4062)
FrankLeeeee Jun 22, 2023
8108c35
[shardformer] add linearconv1d test (#4067)
FoolPlayer Jun 22, 2023
a484c71
[shardformer] supported fused qkv checkpoint (#4073)
FrankLeeeee Jun 23, 2023
12801e8
[shardformer] Add layernorm (#4072)
FoolPlayer Jun 23, 2023
d88844c
[test] fixed tests failed due to dtensor change (#4082)
FrankLeeeee Jun 26, 2023
4e0db99
[shardformer] refactored layernorm (#4086)
FrankLeeeee Jun 26, 2023
a7433a0
[shardformer] shardformer support opt models (#4091)
flybird11111 Jun 27, 2023
ad604f7
[shardformer] support vision transformer (#4096)
klhhhhh Jun 28, 2023
8b0930c
[shardformer] supported bloom model (#4098)
FrankLeeeee Jun 28, 2023
92e669e
[shardformer] supported fused normalization (#4112)
FrankLeeeee Jun 30, 2023
8d3f077
[shardformer] integrate with data parallelism (#4103)
FrankLeeeee Jun 30, 2023
60d2cad
[shardformer] import huggingface implicitly (#4101)
FrankLeeeee Jun 30, 2023
26ecfd7
[shardformer] added embedding gradient check (#4124)
FrankLeeeee Jun 30, 2023
b6f4e05
[shardformer] write an shardformer example with bert finetuning (#4126)
flybird11111 Jun 30, 2023
1b4a901
[shardformer] refactored some doc and api (#4137)
FrankLeeeee Jul 3, 2023
f8dcf9d
[shardformer] made tensor parallelism configurable (#4144)
FrankLeeeee Jul 4, 2023
d1db043
[shardformer] added development protocol for standardization (#4149)
FrankLeeeee Jul 4, 2023
dd9fe39
[chat] removed cache file (#4155)
FrankLeeeee Jul 4, 2023

Files changed

@@ -188,7 +188,7 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV
         remove_strategy_list = []
         for strategy in self.strategies_vector:
             shard_axis_list = []
-            last_axis = len(self.device_mesh.mesh_shape) - 1
+            last_axis = len(self.device_mesh.shape) - 1
             for op_data, sharding_spec in strategy.sharding_specs.items():
                 if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
                     for dim, shard_axes in sharding_spec.dim_partition_dict.items():

@@ -984,18 +984,18 @@ def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
     def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
         device_mesh_is_1d = True
-        if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
+        if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape:
             device_mesh_is_1d = False

         if device_mesh_is_1d:
             # split only the batch dimension
             # Sb = Sb x Sb
             # can be None as it is only for 1D device mesh
             # only for 1D device mesh
-            if len(self.device_mesh.mesh_shape) == 1:
+            if len(self.device_mesh.shape) == 1:
                 mesh_dim = 0
             else:
-                mesh_dim = self.device_mesh.mesh_shape.index(1)
+                mesh_dim = self.device_mesh.shape.index(1)
             strategy_list.append(self.split_one_batch_dim(mesh_dim))
         else:
             # for 2D device mesh

colossalai/auto_parallel/tensor_shard/utils/misc.py (2 additions, 2 deletions)
@@ -46,8 +46,8 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
     # make sure all dims are covered in sharding spec
     sharding_len = len(sharding_spec.sharding_sequence)
     tensor_num_dim = tensor.dim()
-    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
-    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
+    num_devices_in_col = sharding_spec.device_mesh.shape[0]
+    num_devices_in_row = sharding_spec.device_mesh.shape[1]
     assert sharding_len == tensor_num_dim, \
         f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

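Across the three hunks above, the only functional change is the device-mesh attribute rename (mesh_shape to shape); every call site still treats the attribute as a plain tuple of per-axis device counts. A small sketch of the access patterns seen in these diffs, using a hypothetical (1, 4) mesh shape as a stand-in for a real DeviceMesh instance:

# Hypothetical stand-in for device_mesh.shape after the rename; a real DeviceMesh
# object would expose this attribute directly.
mesh_shape = (1, 4)

last_axis = len(mesh_shape) - 1                                          # index of the last mesh axis (node-handler hunk)
num_devices_in_col, num_devices_in_row = mesh_shape[0], mesh_shape[1]    # per-axis device counts (misc.py hunk)
device_mesh_is_1d = not (len(mesh_shape) == 2 and 1 not in mesh_shape)   # 1D-vs-2D check (strategy-generator hunk)
mesh_dim = 0 if len(mesh_shape) == 1 else mesh_shape.index(1)            # mesh axis passed to split_one_batch_dim

Judging from these call sites, callers only need to switch the attribute name; len(), indexing, and index() behave as before.
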
colossalai/checkpoint_io/utils.py (3 additions, 3 deletions)
@@ -10,7 +10,7 @@
 import torch.nn as nn
 from torch.optim import Optimizer

-from colossalai.tensor.d_tensor.d_tensor import DTensor
+from colossalai.tensor.d_tensor import is_distributed_tensor

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
     for key, weight in state_dict.items():
         ret_block = None
         ret_block_size = 0
-        if type(weight) != DTensor:
+        if not is_distributed_tensor(weight):
             weight_size = calculate_tensor_size(weight)

             # If this weight is going to tip up over the maximal size, we split.
@@ -146,7 +146,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
                 continue

             # If the states are stored as DTensors, mark isDTensor as true.
-            if type(state_tensor) == DTensor:
+            if is_distributed_tensor(state_tensor):
                 isDTensor = True
             state_size += calculate_tensor_size(state_tensor)

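The checkpoint_io change swaps the direct DTensor type comparison for the is_distributed_tensor predicate, so both shard_model_checkpoint and shard_optimizer_checkpoint detect sharded tensors by inspection rather than by concrete class. A minimal sketch of how such a guard might be applied outside the PR's own functions (the summarize_state_dict helper is hypothetical; only the import path is taken from the diff):

from colossalai.tensor.d_tensor import is_distributed_tensor  # import path as added in this PR

def summarize_state_dict(state_dict: dict) -> dict:
    """Illustrative only: count plain vs. distributed tensors and the plain-tensor bytes.

    Mirrors the guard added in shard_model_checkpoint: a tensor is treated as sharded
    when is_distributed_tensor() says so, instead of comparing its type to DTensor.
    """
    stats = {"plain": 0, "distributed": 0, "plain_bytes": 0}
    for tensor in state_dict.values():
        if is_distributed_tensor(tensor):
            stats["distributed"] += 1          # sharded storage is accounted for elsewhere
        else:
            stats["plain"] += 1
            stats["plain_bytes"] += tensor.numel() * tensor.element_size()
    return stats

The same predicate guards the optimizer-state path in the diff, so any tensor implementation that reports itself as distributed is handled uniformly by both shard functions.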