Commit

Merge pull request #16 from argonne-lcf/distributed_loading
Distributed data loading
saforem2 authored Jun 19, 2024
2 parents d012937 + a36569e commit ef5356b
Showing 12 changed files with 562 additions and 205 deletions.
104 changes: 69 additions & 35 deletions ALCF/test_blendable_dataset.py
@@ -1,40 +1,66 @@
#!/usr/bin/env python
import time
start_time = time.time()
from mpi4py import MPI
import os
from megatron.data.gpt_dataset import build_train_valid_test_datasets
import numpy as np
from megatron.global_vars import set_args, set_global_variables, get_args
from megatron.arguments import parse_args
from megatron.initialize import initialize_megatron
from megatron.data.data_samplers import build_pretraining_data_loader
from mpi4py import MPI

import torch
from megatron.core import mpu


comm = MPI.COMM_WORLD
from megatron.utils import PerfTrace, Profile


import datetime
def print_rank_0(msg):
if comm.rank==0:
print(f" [INFO][{datetime.datetime.now()}] {msg}", flush=True)
end_time = time.time()
print_rank_0(f"Loaded python modules in {end_time - start_time} seconds")
initialize_megatron(allow_no_cuda=True)
comm.Barrier()
print_rank_0(f"Barrier synchonization time: {time.time() - end_time} seconds")
args = get_args()
if os.getenv('DLIO_PROFILER_DATASET_DIR') is not None:
extra_trace_path = os.environ['DLIO_PROFILER_DATASET_DIR']
else:
extra_trace_path=''
PerfTrace.initialize_log(f"{args.trace_dir}/trace-{comm.rank}-of-{comm.size}.pfw", f"{args.data_cache_path}:{extra_trace_path}:{args.data_path}:{args.save}:{args.load}", process_id=comm.rank)
dlp = Profile("TEST_BLENDABLEDATASET")

os.makedirs(args.trace_dir, exist_ok=True)


data_file_list = args.data_file_list
if comm.rank==0:
print(f"Reading data from {args.data_file_list}")
print_rank_0(f"Reading data from {args.data_file_list}")
files = []
weights = []
flist = []
with open(data_file_list, 'r') as fin:
for f in fin.readlines():
w, fname = f.split()
w, fname, c = f.split()
weights.append(float(w))
flist.append(fname)
files.append(float(w))
files.append(fname)
files.append(c)
splits_string="100,0,0"

weights = np.array(weights)
weights = weights/np.sum(weights)

num_samples = args.global_batch_size*args.train_iters
num_datasets = len(weights)
if comm.rank==0:
print(f"Number of datasets: {num_datasets}")
print(f"Global batch size: {args.global_batch_size}")
print(f"Training iterations: {args.train_iters}")
print_rank_0(f"Number of datasets: {num_datasets}")
print_rank_0(f"Global batch size: {args.global_batch_size}")
print_rank_0(f"Training iterations: {args.train_iters}")
train_valid_test_num_samples = [num_samples, 0, 0]
seed=args.seed
data_impl = args.data_impl
@@ -43,38 +69,46 @@
splits_string = "1,0,0"

# Build datasets
start_build_dataset = time.time()

print_rank_0(f"Starting to build the blendable dataset")
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(files, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup, data_cache_path=args.data_cache_path)

dataset_idx = [train_ds.dataset_index[i] for i in range(num_samples)]
ratio_select=np.zeros(num_datasets)
#for i in range(num_datasets):
# ratio_select[i] = np.sum([i==d for d in dataset_idx])/num_samples
if comm.rank ==0:
print(f"Total number of samples: {len(train_ds)}")
print(f"Weights set: {weights[:min(8, num_datasets)]}")
#print(f"Weights across training: {ratio_select[:min(8, num_datasets)]}")

for e in range(min(100, args.train_iters)):
ratio_select=np.zeros(num_datasets)
for i in range(num_datasets):
ratio_select[i] = np.sum([i==d for d in dataset_idx[e*args.global_batch_size:(e+1)*args.global_batch_size]])/args.global_batch_size
if comm.rank==0:
print(f"iter-{e}: {ratio_select[:min(8, num_datasets)]}")

end_build_dataset = time.time()
print_rank_0(f"Finished building the blendable dataset in {end_build_dataset - start_build_dataset} second")
print_rank_0(f"Total number of samples: {len(train_ds)}")
print_rank_0(f"Weights set: {weights[:min(8, num_datasets)]}")

print("First 10 samples")
for i in range(10):
if comm.rank==0:
print(f"Sample: {i} \t dataset_idx: {train_ds.dataset_index[i]}, sample_idx: {train_ds.dataset_sample_index[i]}")

#### Build data loaders
start_build_dataloader = time.time()
print_rank_0(f"Starting to build the data loader")
rank_in_parallel_group = mpu.get_sequence_parallel_rank()
print(rank_in_parallel_group)
if rank_in_parallel_group == 0:
train_dataloader = build_pretraining_data_loader(
train_ds, args.consumed_train_samples)
valid_dataloader = build_pretraining_data_loader(
train_dataloader = build_pretraining_data_loader(
train_ds, args.consumed_train_samples)
valid_dataloader = build_pretraining_data_loader(
valid_ds, args.consumed_valid_samples)
test_dataloader = build_pretraining_data_loader(test_ds, 0)
test_dataloader = build_pretraining_data_loader(test_ds, 0)
end_build_dataloader = time.time()
print_rank_0(f"Finished building the data loader in {end_build_dataloader - start_build_dataloader} second")

print_rank_0(f"Starting loading the data")
start_loading_time = time.time()
NUM_ITEMS=1
SLEEP_TIME=10.0
@dlp.log
def compute(ct):
time.sleep(ct)
n=0
start_time = time.time()
for i in dlp.iter(train_dataloader):
print(f"[{comm.rank}] DATA {i}")
n+=1
if (n%NUM_ITEMS==0):
print_rank_0(f"Proccessed {n}th-batch in {time.time() - start_time}")
if n>=1000:
break
start_time = time.time()
end_loading_time = time.time()
print_rank_0(f"Finished loading the data ({n} batches) in {end_loading_time - start_loading_time}")
4 changes: 4 additions & 0 deletions megatron/arguments.py
@@ -874,6 +874,8 @@ def _add_training_args(parser):
'training if SIGTERM is received')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--trace-dir', type=str, default="./trace/",
help='Write trace logs to this directory.')
group.add_argument('--no-masked-softmax-fusion',
action='store_false',
help='Disable fusion of query_key_value scaling, '
@@ -1502,6 +1504,8 @@ def _add_zero_args(parser):
help='Remote device for ZeRO-3 initialized parameters.')
group.add_argument('--use-pin-memory', action='store_true',
help='Use pinned CPU memory for ZeRO-3 initialized model parameters.')
group.add_argument('--use-mics', action='store_true',
help='Use MiCS')
return parser

def _add_memoryopt_args(parser):
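
For reference, a hedged illustration of how the two new flags slot into an existing launch command; only the flag names and the --trace-dir default come from this diff, and the rest of the command line is an elided placeholder:

# Hypothetical fragment of a training launch (everything except the two new flags is elided):
#   ... \
#       --trace-dir ./trace/   # new: directory for profiler trace output (default ./trace/)
#       --use-mics             # new: opt in to MiCS (flag added to the ZeRO argument group)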
23 changes: 12 additions & 11 deletions megatron/checkpointing.py
@@ -21,6 +21,7 @@
from .utils import (unwrap_model,
print_rank_0,
is_rank_0)
from .utils import PerfTrace, Profile

from deepspeed.checkpoint import (
ORIGINAL_VOCAB_SIZE,
@@ -38,7 +39,7 @@

_CHECKPOINT_VERSION = None


dlp = Profile("CHECKPOINT")
def set_checkpoint_version(value):
global _CHECKPOINT_VERSION
if _CHECKPOINT_VERSION is not None:
@@ -167,7 +168,7 @@ def get_checkpoint_tracker_filename(checkpoints_path):
training to restart from."""
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


@dlp.log
def read_metadata(tracker_filename):
# Read the tracker file and either set the iteration or
# mark it as a release checkpoint.
@@ -207,7 +208,7 @@ def read_metadata(tracker_filename):
max_iter = iteration
return max_iter, release


@dlp.log
def get_rng_state():
""" collect rng state across data parallel ranks """
args = get_args()
@@ -233,7 +234,7 @@ def get_rng_state():

return rng_state_list


@dlp.log
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
"""Save a model checkpoint."""
args = get_args()
@@ -338,7 +339,7 @@ def state_dict_for_save_checkpoint_deepspeed(destination=None, prefix='', keep_v
if torch.distributed.is_initialized():
torch.distributed.barrier()


@dlp.log
def _transpose_first_dim(t, num_splits, num_splits_first, model):
input_shape = t.size()
# We use a self_attention module but the values extracted aren't
@@ -408,7 +409,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))


@dlp.log
def _load_base_checkpoint(load_dir, rank0=False):
""" Load the base state_dict from the given directory
@@ -463,7 +464,7 @@ def _load_base_checkpoint(load_dir, rank0=False):

return state_dict, release


@dlp.log
def load_args_from_checkpoint(args, load_arg='load'):
"""Set required arguments from the checkpoint specified in the
arguments.
@@ -544,7 +545,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
_set_arg('num_layers_per_virtual_pipeline_stage')
return args, checkpoint_args


@dlp.log
def load_lr_state_dict(strict: bool = False) -> dict:
"""Load {iteration, lr} from .yaml file when restoring from checkpoint."""
args = get_args()
@@ -568,7 +569,7 @@ def load_lr_state_dict(strict: bool = False) -> dict:
)
return lr_state_dict


@dlp.log
def save_lr_state_dict() -> None:
"""Save {iteration, lr} to .yaml file for safe-keeping.
@@ -590,7 +591,7 @@ def save_lr_state_dict() -> None:
f
)


@dlp.log
def load_checkpoint(
model,
optimizer,
@@ -811,7 +812,7 @@ def load_checkpoint(

return iteration


@dlp.log
def load_biencoder_checkpoint(model, only_query_model=False,
only_context_model=False, custom_load_path=None):
"""
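
The decorators added throughout this file follow the PR's general instrumentation pattern: each module creates one Profile object with a category label, wraps I/O- and compute-heavy functions with @dlp.log, and wraps loops with dlp.iter (as in schedules.py below). A minimal sketch of that pattern, using only the names and call signatures visible in this diff; the trace path, category label, and example function are placeholders:

from mpi4py import MPI
from megatron.utils import PerfTrace, Profile   # profiler helpers re-exported by megatron.utils

comm = MPI.COMM_WORLD

# One trace file per rank; the second argument is a ":"-separated list of directories
# to associate with the trace (signature as used in ALCF/test_blendable_dataset.py above).
PerfTrace.initialize_log(f"./trace/trace-{comm.rank}-of-{comm.size}.pfw",
                         "/placeholder/data/dir", process_id=comm.rank)

dlp = Profile("EXAMPLE")                # category label attached to the emitted events

@dlp.log                                # records entry/exit of the decorated function
def read_tracker_file(path):
    with open(path) as f:
        return f.read()

for _ in dlp.iter(range(4)):            # records each iteration of the wrapped loop
    pass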
27 changes: 15 additions & 12 deletions megatron/core/pipeline_parallel/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from megatron.utils import print_rank_0, unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import Profile

dlp = Profile("CORE")
# Types
Shape = Union[List[int], torch.Size]

@@ -124,6 +126,7 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
dtype = out.dtype,
)

@dlp.log
def custom_backward(output, grad_output):
'''Directly call C++ autograd engine.
@@ -162,7 +165,7 @@ def custom_backward(output, grad_output):




@dlp.log
def forward_step(forward_step_func,
data_iterator,
model,
@@ -227,7 +230,7 @@ def forward_step(forward_step_func,
return output_tensor
return [output_tensor]


@dlp.log
def backward_step(
input_tensor,
output_tensor,
@@ -305,7 +308,7 @@ def backward_step(
config.timers('backward-compute').stop()
return input_tensor_grad


@dlp.log
def forward_backward_no_pipelining(*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
@@ -352,7 +355,7 @@ def forward_backward_no_pipelining(*,
forward_data_store = []
input_tensor, output_tensor_grad = None, None
with no_sync_func():
for i in range(num_microbatches - 1):
for i in dlp.iter(range(num_microbatches - 1)):
output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
input_tensor, forward_data_store, config, collect_non_loss_data)
if not forward_only:
@@ -370,7 +373,7 @@

return forward_data_store


@dlp.log
def forward_backward_pipelining_with_interleaving(*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
@@ -923,7 +926,7 @@ def get_tensor_shapes(*,
return tensor_shapes



@dlp.log
def recv_forward(tensor_shapes, config):
input_tensors = []
for tensor_shape in tensor_shapes:
@@ -933,7 +936,7 @@ def recv_forward(tensor_shapes, config):
input_tensors.append(p2p_communication.recv_forward(tensor_shape, config))
return input_tensors


@dlp.log
def recv_backward(tensor_shapes, config):
output_tensor_grads = []
for tensor_shape in tensor_shapes:
@@ -943,7 +946,7 @@ def recv_backward(tensor_shapes, config):
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, config))
return output_tensor_grads


@dlp.log
def send_forward(output_tensors, tensor_shapes, config):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
@@ -952,7 +955,7 @@ def send_forward(output_tensors, tensor_shapes, config):
continue
p2p_communication.send_forward(output_tensor, config)


@dlp.log
def send_backward(input_tensor_grads, tensor_shapes, config):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
@@ -961,7 +964,7 @@ def send_backward(input_tensor_grads, tensor_shapes, config):
continue
p2p_communication.send_backward(input_tensor_grad, config)


@dlp.log
def send_forward_recv_backward(output_tensors, tensor_shapes, config):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
@@ -975,7 +978,7 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, config):
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads


@dlp.log
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
@@ -989,7 +992,7 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
input_tensors.append(input_tensor)
return input_tensors


@dlp.log
def forward_backward_pipelining_without_interleaving(*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],