Skip to content

Commit

Permalink
Merge branch 'pytorch-labs:main' into expand_multi_node
Browse files Browse the repository at this point in the history
  • Loading branch information
lessw2020 authored Feb 22, 2024
2 parents d345408 + 55a6b0b commit 5dda674
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 20 deletions.
23 changes: 8 additions & 15 deletions torchtrain/datasets/alpaca.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchtrain.datasets.tokenizer import TokenizerIf

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node


class AlpacaDataset(IterableDataset):
Expand Down Expand Up @@ -44,32 +45,24 @@ def __init__(
rank: int = 0,
**kwargs
) -> None:
self._data = load_dataset("tatsu-lab/alpaca", split="train")
# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
# Setting `streaming=True` works for large dataset, but the speed is slow.
ds = load_dataset("tatsu-lab/alpaca", split="train")
self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
self._tokenizer = tokenizer
self.data_iterator = iter(self._data)
self.seq_len = seq_len
self.world_size = world_size
self.rank = rank
self.response_tag = "\n\n### Response:\n"

def __len__(self):
    # Number of samples in the dataset.
    # NOTE(review): self._data is presumably the Hugging Face dataset loaded
    # in __init__ — but the surrounding diff shows that assignment being
    # removed in favor of a node-split iterator, so this method may be
    # deleted/stale in the new revision; confirm against the full file.
    return len(self._data)

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len
all_tokens: List[int] = []

for idx, sample in enumerate(self.data_iterator):
# select samples to pack in a round-robin fashion
# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
if idx % self.world_size != self.rank:
continue
for sample in self.data_iterator:
sample_text = sample["text"]
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
all_tokens.extend(sample_tokens)

if len(all_tokens) >= max_buffer_token_len:
while len(all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(all_tokens[:max_buffer_token_len])
# batched_x = x.reshape(self.batch_size, -1)
# update tokens to the remaining tokens
Expand Down
6 changes: 3 additions & 3 deletions torchtrain/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,15 @@ def get_current_stats(self, return_data: bool = False):
)

display_str = ""
display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%,"
display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"
display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%, "
display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"

self.get_peak_stats(curr_mem)

peak_active_pct = self.get_pct_memory(self.peak_active_memory)
peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory)
peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory)
display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"
display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"

display_str += f"num retries: {self.num_retries}, num ooms: {self.num_ooms}"
if self.num_retries > 0:
Expand Down
12 changes: 10 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,18 @@ def main(args):
time_delta * parallel_dims.model_parallel_size
)

gpu_mem_stats = gpu_metrics.get_current_stats(return_data=True)

metrics = {
"global_avg_loss": global_avg_loss,
"global_max_loss": global_max_loss,
"loss_metrics/global_avg_loss": global_avg_loss,
"loss_metrics/global_max_loss": global_max_loss,
"wps": wps,
"memory_current/active(%)": gpu_mem_stats.active_curr,
"memory_current/allocated(%)": gpu_mem_stats.allocated_curr,
"memory_current/reserved(%)": gpu_mem_stats.reserved_curr,
"memory_peak/active(%)": gpu_mem_stats.active_peak,
"memory_peak/allocated(%)": gpu_mem_stats.allocated_peak,
"memory_peak/reserved(%)": gpu_mem_stats.reserved_peak,
}
metric_logger.log(metrics, step=train_state.step)

Expand Down

0 comments on commit 5dda674

Please sign in to comment.