Add sp tp #5

Merged: 15 commits, Jul 26, 2024
tools/fsdp_sft.py: 164 changes (87 additions & 77 deletions)
@@ -14,7 +14,6 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from accelerate.utils import set_module_tensor_to_device
from datasets import Dataset
from mmengine import mkdir_or_exist
from mmengine.dist import infer_launcher, init_dist
from mmengine.runner import set_random_seed
@@ -40,20 +39,25 @@
is_torch_sdpa_available)

from xtuner._lite import AutoTokenizer, get_logger
from xtuner._lite.accelerate import (LORA_TARGET_MAP, LoadWoInit,
dispatch_modules, packed_sequence)
from xtuner._lite.accelerate.fsdp import (RECOMPUTE_MODULES,
all_required_grad_wrap_policy,
checkpoint_check_fn, dp_lazy_init,
layer_auto_wrap_policy)
from xtuner._lite.accelerate import (LORA_TARGET_MAP, dispatch_modules,
packed_sequence)
from xtuner._lite.chat import CHAT_TEMPLATE_MAP
from xtuner._lite.datasets import (OPENAI_FORMAT_MAP, HardPackerForText,
SoftPackerForText, TextCollator,
TextOnlineTokenizeDataset,
from xtuner._lite.datasets import (OPENAI_FORMAT_MAP, SoftPackerForText,
TextCollator, TextOnlineTokenizeDataset,
TextTokenizedDataset, TextTokenizeFunction)
from xtuner._lite.datasets.load import (LOAD_FN_MAP, load_datasets,
load_from_cache)
from xtuner._lite.parallel import LengthGroupedSampler, ParallelSampler
from xtuner._lite.parallel import (LengthGroupedSampler, ParallelSampler,
get_dp_mesh, get_dp_world_size,
get_sp_group, get_sp_mesh,
get_sp_world_size,
reduce_sequence_parallel_loss,
setup_parallel, split_for_sequence_parallel)
from xtuner._lite.parallel.fsdp import (RECOMPUTE_MODULES, LoadWoInit,
all_required_grad_wrap_policy,
checkpoint_check_fn, dp_lazy_init,
dp_sp_lazy_init,
layer_auto_wrap_policy)

logger = get_logger()
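
Note on the new imports: the FSDP helpers move from `xtuner._lite.accelerate.fsdp` to `xtuner._lite.parallel.fsdp`, and the sequence-parallel (SP) utilities come in: setup_parallel, get_dp_mesh/get_sp_mesh/get_sp_group, split_for_sequence_parallel and reduce_sequence_parallel_loss. As a rough mental model only (not xtuner's actual implementation), setup_parallel can be pictured as building a 2-D (dp, sp) device mesh; the function and variable names below are illustrative assumptions:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

_MESH = None  # hypothetical module-level cache for the 2-D mesh


def setup_parallel_sketch(sp_size: int = 1):
    """Illustrative only: build a (dp, sp) mesh with world_size == dp * sp."""
    global _MESH
    world_size = dist.get_world_size()
    assert world_size % sp_size == 0, 'world_size must be divisible by sp_size'
    # Ranks along the 'dp' dim do data-parallel work on different batches;
    # ranks along the 'sp' dim each hold a different chunk of the same sequence.
    _MESH = init_device_mesh(
        'cuda', (world_size // sp_size, sp_size), mesh_dim_names=('dp', 'sp'))
    return _MESH['dp'], _MESH['sp']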

@@ -134,6 +138,8 @@ def parse_args():
choices=['full', 'hybrid'],
help=('The sharding strategy to be used for distributed training.'))
model_args.add_argument('--cpu-offload', action='store_true', help=(''))
model_args.add_argument('--sp-size', type=int, default=1, help='')

data_args = parser.add_argument_group('data', 'Dataset Related Settings')
data_args.add_argument(
'--datasets',
@@ -328,7 +334,12 @@ def sft(args):
set_random_seed(args.seed)

world_size = int(os.environ['WORLD_SIZE'])
dp_size = world_size
sp_size = args.sp_size

setup_parallel(sp_size=sp_size)
dp_mesh = get_dp_mesh()
sp_mesh = get_sp_mesh()
dp_size = get_dp_world_size()

if args.global_batch_size < dp_size or args.global_batch_size % dp_size:
raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
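
With sequence parallelism enabled, the data-parallel world shrinks to world_size // sp_size, and the global batch size must still split evenly across the remaining DP ranks, which is what the check above enforces. A small worked example (the numbers are illustrative only):

# 8 GPUs with sp_size=2 leave 4 data-parallel ranks.
world_size, sp_size = 8, 2
dp_size = world_size // sp_size                      # 4
global_batch_size = 16                               # must be a multiple of dp_size
assert not (global_batch_size < dp_size or global_batch_size % dp_size)
samples_per_dp_rank = global_batch_size // dp_size   # 4 samples per step per DP rank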
@@ -346,12 +357,7 @@
'folder, which may lead to inaccurate '
'cache results.')

device_mesh = init_device_mesh(
'cuda', (dp_size, ), mesh_dim_names=('dp', ))

dp_mesh = device_mesh['dp']

rank = dp_mesh.get_local_rank()
rank = dist.get_rank()
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')

objects = [timestamp]
@@ -396,18 +402,25 @@ def sft(args):

start_load_data_t = time.time()

chat_template = CHAT_TEMPLATE_MAP[args.chat_template]

tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer if args.tokenizer else args.llm,
trust_remote_code=True,
padding_side='right')

if args.dset_from_cache:
# packer = partial(SoftPackerForText, max_length=args.max_length)
_datasets = load_from_cache(args.dset_cache_dir)
else:
if args.dset_pack_level == 'soft':
init_fn = partial(
SoftPackerForText.from_cache, max_length=args.max_length)
elif args.dset_pack_level == 'hard':
raise NotImplementedError
else:
init_fn = partial(
TextTokenizeFunction.from_cache, max_length=args.max_length)
_datasets = load_from_cache(args.dset_cache_dir, init_fn)
dist.barrier()

else:
chat_template = CHAT_TEMPLATE_MAP[args.chat_template]
tokenize_fns = []
init_fns = []
for dset_format in args.dset_formats:
@@ -419,16 +432,17 @@
tokenize_fn = TextTokenizeFunction(tokenizer, chat_template,
dset_format)

if args.dset_cache_dir or args.dset_pack_level:
# Before caching or packing dataset, you need to first
# tokenize the dataset and then transform it into a
# Huggingface `Dataset`
init_fn = Dataset.from_list
if args.dset_pack_level == 'soft':
init_fn = partial(
SoftPackerForText, max_length=args.max_length)
elif args.dset_cache_dir:
init_fn = partial(
TextTokenizedDataset, max_length=args.max_length)
else:
# Use online tokenize when there is no need to cache or pack
# the dataset, thereby saving startup time.
init_fn = partial(
TextOnlineTokenizeDataset, tokenize_fn=tokenize_fn)
# Online tokenization is used when not using a pack dataset,
# saving startup time.
tokenize_fn = None

tokenize_fns.append(tokenize_fn)
@@ -447,40 +461,13 @@

if (args.dset_pack_level or args.cache_dir) and rank == 0 and args.debug:
# Only the tokenized datasets can count the number of tokens
num_tokens = sum(sum(dset['num_tokens']) for dset in _datasets)
num_tokens = sum(sum(dset.total_tokens) for dset in _datasets)
logger.debug(f'[Dataset] {num_tokens} tokens.')

num_datasets = len(_datasets)
datasets = []
if args.dset_pack_level and args.dset_pack_level == 'soft':
pack_infos = SoftPackerForText.get_pack_infos(_datasets,
args.max_length)
for i in range(num_datasets):
_infos = pack_infos[i]
_dset = _datasets[i]
_packed_dset = SoftPackerForText(_dset, args.max_length, _infos)
datasets.append(_packed_dset)
elif args.dset_pack_level and args.dset_pack_level == 'hard':
pack_infos = HardPackerForText.get_pack_infos(_datasets,
args.max_length)
for i in range(num_datasets):
_infos = pack_infos[i]
_dset = _datasets[i]
_packed_dset = HardPackerForText(_dset, args.max_length, _infos)
datasets.append(_packed_dset)
elif args.dset_pack_level is None and args.dset_cache_dir:
datasets = []
for dset in _datasets:
datasets.append(TextTokenizedDataset(dset))
else:
datasets = []
for dset in _datasets:
datasets.append(TextOnlineTokenizeDataset(dset))

train_dataset = ConcatDataset(datasets)
train_dataset = ConcatDataset(_datasets)

if args.dset_pack_level and rank == 0:
ori_samples = sum([len(dset) for dset in _datasets])
ori_samples = sum([dset.num_samples for dset in _datasets])
packed_samples = len(train_dataset)
logger.info(f'[Dataset] (Original) {ori_samples} samples.')
logger.info(f'[Dataset] (Packed) {packed_samples} samples.')
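
The removed block shows why this hunk shrinks: pack-info construction used to live in this script, but SoftPackerForText (and the from_cache constructors used earlier) now handle it internally, so the script only concatenates the prepared datasets. For intuition, soft packing greedily groups whole tokenized samples into packs that stay within max_length, rather than cutting at fixed boundaries; the helper below is a simplified sketch of that idea, not xtuner's code:

def soft_pack_infos_sketch(tokens_per_sample, max_length):
    """Greedily group sample indices so each pack fits within max_length."""
    packs, current, current_len = [], [], 0
    for idx, n in enumerate(tokens_per_sample):
        if current and current_len + n > max_length:
            packs.append(current)
            current, current_len = [], 0
        current.append(idx)
        current_len += n
    if current:
        packs.append(current)
    return packs

# e.g. samples of 300, 500 and 900 tokens with max_length=1024 -> [[0, 1], [2]]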
@@ -540,7 +527,7 @@ def sft(args):
raise RuntimeError('The device does not support `bf16`, '
'please set `dtype` to `fp16`.')
else:
raise RuntimeError('`dtype` only supports `fp16``bf16`, or `auto`, '
raise RuntimeError('`dtype` only supports `fp16`, `bf16` or `auto`, '
f'but found {args.dtype}.')

llm_cfg = AutoConfig.from_pretrained(args.llm, trust_remote_code=True)
@@ -572,15 +559,22 @@

dist.barrier()

param_init_fn = partial(
dp_lazy_init, module_map=meta_llm_map, dp_mesh=dp_mesh)
if get_sp_world_size() > 1:
param_init_fn = partial(
dp_sp_lazy_init,
module_map=meta_llm_map,
dp_mesh=dp_mesh,
sp_mesh=sp_mesh)
else:
param_init_fn = partial(
dp_lazy_init, module_map=meta_llm_map, dp_mesh=dp_mesh)

policies = [layer_auto_wrap_policy]
if args.use_lora:
policies.append(all_required_grad_wrap_policy)

if args.shard_strategy == 'full':
fsdp_device_mesh = dp_mesh
fsdp_device_mesh = init_device_mesh('cuda', (world_size, ))
strategy = ShardingStrategy.FULL_SHARD
elif args.shard_strategy == 'hybrid':
fsdp_device_mesh = init_device_mesh('cuda', (dp_size // 8, 8))
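
Two related changes land in this hunk. Parameter materialization picks dp_sp_lazy_init when sp > 1, so weights built on the meta device end up consistent across both the DP and SP groups, and the FULL_SHARD branch now builds a flat mesh over all ranks instead of reusing dp_mesh, because FSDP shards parameters across every GPU regardless of how sequences are split. The snippet below is only a toy illustration of the two meshes with PyTorch's stock FSDP API (the Linear model and variable names are made up); it is not the wrapping code of this script:

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# Assumes the process group is already initialized (e.g. launched via torchrun).
world_size = dist.get_world_size()

# FULL_SHARD: parameters are sharded across all ranks (dp * sp of them).
full_mesh = init_device_mesh('cuda', (world_size, ))
full_model = FSDP(
    nn.Linear(1024, 1024).cuda(),
    device_mesh=full_mesh,
    sharding_strategy=ShardingStrategy.FULL_SHARD)

# HYBRID_SHARD: shard inside each 8-GPU node, replicate across nodes.
hybrid_mesh = init_device_mesh('cuda', (world_size // 8, 8))
hybrid_model = FSDP(
    nn.Linear(1024, 1024).cuda(),
    device_mesh=hybrid_mesh,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD)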
@@ -687,7 +681,7 @@ def warmup_fn(x):
train_dataloader.sampler.set_epoch(epoch, epoch_inner_step)
data_iterator = iter(train_dataloader)

if step <= warmup_steps:
if step < warmup_steps:
warmup_scheduler.step()
cur_lr = warmup_scheduler.get_last_lr()[0]
else:
@@ -711,15 +705,30 @@ def warmup_fn(x):
attention_mask = data['attention_mask'].cuda()
num_tokens = data['num_tokens'].cuda()

packed_ctx = packed_sequence(num_tokens, enable=pack_batch)
packed_ctx = packed_sequence(
num_tokens, enable=pack_batch, sp_size=get_sp_world_size())

with packed_ctx, autocast if args.use_lora else nullcontext():
if get_sp_world_size() > 1:
sp_group = get_sp_group()
# `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
input_ids = split_for_sequence_parallel(
input_ids, dim=1, sp_group=sp_group)
labels = split_for_sequence_parallel(
labels, dim=1, sp_group=sp_group)

outputs = shard_llm(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask)

loss = outputs.loss
if get_sp_world_size() > 1:
tokens_cal_loss = (labels != -100).sum()
loss = reduce_sequence_parallel_loss(
loss, tokens_cal_loss, sp_group)

with packed_ctx:
with autocast if args.use_lora else nullcontext():
outputs = shard_llm(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask)
avg_iter_loss = outputs.loss / iters_per_step
avg_iter_loss = loss / iters_per_step

if scaler and args.use_lora:
scaler.scale(avg_iter_loss).backward()
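
This is the heart of the SP forward/backward path: each rank keeps only its slice of input_ids and labels along the sequence dimension, computes a mean loss over the label tokens it actually sees, and reduce_sequence_parallel_loss combines the per-rank means, weighted by each rank's count of tokens with labels != -100, so the result matches the loss over the full, unsplit sequence. The two helpers can be pictured roughly as below; this is a hedged sketch assuming an even chunk split and a plain all-reduce over the SP group, not necessarily xtuner's implementation:

import torch.distributed as dist


def split_for_sequence_parallel_sketch(tensor, dim, sp_group):
    """Keep only this rank's contiguous chunk along the sequence dimension."""
    sp_size = dist.get_world_size(group=sp_group)
    sp_rank = dist.get_rank(group=sp_group)
    return tensor.chunk(sp_size, dim=dim)[sp_rank].contiguous()


def reduce_sequence_parallel_loss_sketch(mean_loss, num_loss_tokens, sp_group):
    """Recover the global mean loss from per-rank mean losses."""
    # Convert each rank's mean back to a sum, then all-reduce both the
    # numerator and the denominator across the SP group before dividing.
    loss_sum = mean_loss * num_loss_tokens
    dist.all_reduce(loss_sum, group=sp_group)
    num_loss_tokens = num_loss_tokens.clone()
    dist.all_reduce(num_loss_tokens, group=sp_group)
    return loss_sum / num_loss_tokens

The same picture also motivates the num_tokens.sum() / get_sp_world_size() bookkeeping in the next hunk: every SP rank sees the same num_tokens for the batch, so dividing by the SP world size avoids counting the same tokens sp_size times when the consumed-token counter is aggregated across ranks.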
@@ -732,9 +741,10 @@ def warmup_fn(x):
# still smaller than the max length after packing, will be
# padded to the max length. The last element of num tokens
# represents the count of pad tokens.
step_consumed_tokens += num_tokens[:-1].sum()
step_consumed_tokens += num_tokens[:-1].sum(
) / get_sp_world_size()
else:
step_consumed_tokens += num_tokens.sum()
step_consumed_tokens += num_tokens.sum() / get_sp_world_size()

grad_norm = shard_llm.clip_grad_norm_(args.max_grad_norm)
optimizer.step()
@@ -765,8 +775,8 @@ def warmup_fn(x):

num_digits = len(str(abs(total_steps)))
work_dir = args.work_dir
ckpt_dir = os.path.join(work_dir, f'ckpt-{step:0{num_digits}}')
hf_dir = os.path.join(work_dir, f'hf-{step:0{num_digits}}')
ckpt_dir = os.path.join(work_dir, f'ckpt-{step+1:0{num_digits}}')
hf_dir = os.path.join(work_dir, f'hf-{step+1:0{num_digits}}')
_options = StateDictOptions(cpu_offload=True, full_state_dict=True)

full_model_state_dict = get_model_state_dict(
Expand Down
Loading