Apply isort and black reformatting
Signed-off-by: ashors1 <ashors1@users.noreply.github.com>
ashors1 committed Jul 9, 2024
1 parent 11279d7 commit 925dcbc
Showing 23 changed files with 512 additions and 375 deletions.
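
Almost every hunk below is one of two mechanical changes. The first, visible immediately below and throughout the diff, is black's "magic trailing comma": a call or signature written with a trailing comma is split onto one argument per line even when it would fit on one. (The other recurring changes here are dropping the padding space after opening docstring quotes and re-parenthesizing split conditional expressions, shown further down.) The snippet below is a hypothetical illustration, not code from this repository, and assumes a recent black release with default settings:

def translate(text, source_lang, target_lang):
    # Stand-in for model.translate; only the call-site formatting matters here.
    return [f"{source_lang}->{target_lang}: {t}" for t in text]


# Before: fits on one line, but the trailing comma is the "magic" trigger.
out = translate(text=["hello"], source_lang="en", target_lang="de",)

# After black: one argument per line, closing parenthesis on its own line.
out = translate(
    text=["hello"],
    source_lang="en",
    target_lang="de",
)

isort orders the import blocks; neither tool changes runtime behavior, which is why the hunks below are pure reformatting.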
@@ -35,6 +35,7 @@
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import inject_model_parallel_rank


@hydra_runner(config_path="conf", config_name="nmt_megatron_infer")
def main(cfg) -> None:

@@ -91,13 +92,19 @@ def main(cfg) -> None:
src_text.append(line.strip())
if len(src_text) == cfg.batch_size:
translations = model.translate(
text=src_text, source_lang=cfg.source_lang, target_lang=cfg.target_lang,
text=src_text,
source_lang=cfg.source_lang,
target_lang=cfg.target_lang,
)
for translation in translations:
tgt_f.write(translation + "\n")
src_text = []
if len(src_text) > 0:
translations = model.translate(text=src_text, source_lang=cfg.source_lang, target_lang=cfg.target_lang,)
translations = model.translate(
text=src_text,
source_lang=cfg.source_lang,
target_lang=cfg.target_lang,
)
for translation in translations:
tgt_f.write(translation + "\n")

@@ -380,7 +380,9 @@ def __init__(

time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim),
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)

self.input_blocks = nn.ModuleList(
@@ -505,24 +507,26 @@ def __init__(
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
)
),
ResBlock(
ch,
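
The hunk above shows the second recurring change in this commit: when a conditional (ternary) expression has to be split across lines, newer black releases wrap it in its own parentheses rather than leaving the bare `if`/`else` floating inside the enclosing call. Below is a condensed, self-contained sketch of the same transformation, with stand-in classes and most arguments elided; the exact black version used for this commit is an assumption:

class AttentionBlock:
    def __init__(self, ch): self.ch = ch


class SpatialTransformer:
    def __init__(self, ch): self.ch = ch


use_spatial_transformer = True
ch = 320

# Before: the split conditional sits bare inside the enclosing list/call.
layers_old = [
    AttentionBlock(ch)
    if not use_spatial_transformer
    else SpatialTransformer(ch),
]

# After: black parenthesizes the conditional expression when splitting it.
layers_new = [
    (
        AttentionBlock(ch)
        if not use_spatial_transformer
        else SpatialTransformer(ch)
    ),
]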
@@ -684,7 +688,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):
# handle asynchronous grad reduction
no_sync_func = None
if not forward_only and self.with_distributed_adam:
no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,)
no_sync_func = partial(
self._optimizer.no_sync,
greedy_grad_copy=self.megatron_amp_O2,
)

# pipeline schedules will get these from self.model.config
for module in self.get_module_list():
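
For context on the `partial` above: binding `greedy_grad_copy` ahead of time lets the pipeline schedule later call `no_sync_func()` with no arguments and get a context manager that defers gradient reduction. The following is a toy, self-contained sketch with a stand-in optimizer; the real `no_sync` belongs to the distributed Adam optimizer and is only assumed here to behave like the familiar DDP-style no-sync context manager:

from contextlib import contextmanager
from functools import partial


class FakeDistributedOptimizer:
    """Stand-in for the distributed optimizer referenced in the hunk above."""

    @contextmanager
    def no_sync(self, greedy_grad_copy=False):
        # The real implementation defers gradient all-reduce while active;
        # this fake only mimics the context-manager shape.
        yield


optimizer = FakeDistributedOptimizer()
no_sync_func = partial(optimizer.no_sync, greedy_grad_copy=True)

with no_sync_func():  # how a pipeline schedule would typically consume it
    pass  # backward passes inside this block would skip per-step reduction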
@@ -728,12 +735,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):

def training_step(self, dataloader_iter):
"""
Our dataloaders produce a micro-batch and then we fetch
a number of microbatches depending on the global batch size and model parallel size
from the dataloader to produce a list of microbatches.
Batch should be a list of microbatches and those microbatches should on CPU.
Microbatches are then moved to GPU during the pipeline.
The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
Our dataloaders produce a micro-batch and then we fetch
a number of microbatches depending on the global batch size and model parallel size
from the dataloader to produce a list of microbatches.
Batch should be a list of microbatches and those microbatches should on CPU.
Microbatches are then moved to GPU during the pipeline.
The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
"""
# we zero grads here because we also call backward in the apex fwd/bwd functions
self._optimizer.zero_grad()
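
The docstring above states the contract these training-step hunks rely on: the dataloader yields micro-batches that stay on CPU, enough of them are fetched to form one global batch, and the pipeline schedule moves them to GPU and drives the Apex/Megatron fwd/bwd functions over the list. A minimal, runnable sketch of the fetching step follows; `get_global_batch` and its arguments are hypothetical names, not NeMo APIs:

def get_global_batch(dataloader_iter, num_microbatches):
    # Each item from the iterator is one micro-batch; keep them on CPU so the
    # pipeline can transfer micro-batches to GPU one at a time.
    return [next(dataloader_iter) for _ in range(num_microbatches)]


# Typically num_microbatches = global_batch_size // (micro_batch_size * data_parallel_size),
# e.g. 16 // (2 * 2) = 4 micro-batches consumed per optimizer step on each rank.
fake_loader = iter([{"pixels": i} for i in range(100)])  # stand-in dataloader
global_batch = get_global_batch(fake_loader, num_microbatches=4)
assert len(global_batch) == 4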
@@ -777,20 +784,20 @@ def training_step(self, dataloader_iter):
return loss_mean

def backward(self, *args, **kwargs):
""" LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from apex.
No need to call it here.
"""LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from apex.
No need to call it here.
"""
pass

def optimizer_zero_grad(self, *args, **kwargs):
""" LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""
pass

def _append_sequence_parallel_module_grads(self, module, grads):
""" Helper method for allreduce_sequence_parallel_gradients"""
"""Helper method for allreduce_sequence_parallel_gradients"""

for param in module.parameters():
sequence_parallel_param = getattr(param, 'sequence_parallel', False)
@@ -803,8 +810,8 @@ def _append_sequence_parallel_module_grads(self, module, grads):

def get_forward_output_and_loss_func(self):
def process_batch(batch):
""" Prepares the global batch for apex fwd/bwd functions.
Global batch is a list of micro batches.
"""Prepares the global batch for apex fwd/bwd functions.
Global batch is a list of micro batches.
"""
# noise_map, condition
batch[self.cfg.first_stage_key] = batch[self.cfg.first_stage_key].cuda(non_blocking=True)
@@ -814,7 +821,8 @@ def process_batch(batch):

# SD has more dedicated structure for encoding, so we enable autocasting here as well
with torch.cuda.amp.autocast(
self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype,
self.autocast_dtype in (torch.half, torch.bfloat16),
dtype=self.autocast_dtype,
):
x, c = self.model.get_input(batch, self.cfg.first_stage_key)

@@ -881,7 +889,7 @@ def validation_step(self, batch, batch_idx):
self.log_dict(val_loss_dict, prog_bar=False, logger=True, on_step=False, on_epoch=True)

def setup(self, stage=None):
""" PTL hook that is executed after DDP spawns.
"""PTL hook that is executed after DDP spawns.
We setup datasets here as megatron datasets require DDP to instantiate.
See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information.
Args:
@@ -935,7 +943,8 @@ def build_train_valid_test_datasets(self):

if self.cfg.first_stage_key.endswith("encoded"):
self._train_ds, self._validation_ds = build_train_valid_precached_datasets(
model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0),
model_cfg=self.cfg,
consumed_samples=self.compute_consumed_samples(0),
)
else:
self._train_ds, self._validation_ds = build_train_valid_datasets(
@@ -989,20 +998,23 @@ def setup_test_data(self, cfg):
f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}'
)
self._test_dl = torch.utils.data.DataLoader(
self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True,
self._test_ds,
batch_size=self._micro_batch_size,
num_workers=cfg.num_workers,
pin_memory=True,
)

def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
""" PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
When using pipeline parallelism, we need the global batch to remain on the CPU,
since the memory overhead will be too high when using a large number of microbatches.
Microbatches are transferred from CPU to GPU inside the pipeline.
"""PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
When using pipeline parallelism, we need the global batch to remain on the CPU,
since the memory overhead will be too high when using a large number of microbatches.
Microbatches are transferred from CPU to GPU inside the pipeline.
"""
return batch

def _validate_trainer(self):
""" Certain trainer configurations can break training.
Here we try to catch them and raise an error.
"""Certain trainer configurations can break training.
Here we try to catch them and raise an error.
"""
if self.trainer.accumulate_grad_batches > 1:
raise ValueError(
@@ -99,7 +99,9 @@ def __init__(self, cfg, model_parallel_config):
self.get_noise_scheduler(self.cfg.noise_scheduler)

self.model_type = None
self.rng = torch.Generator(device=torch.cuda.current_device(),)
self.rng = torch.Generator(
device=torch.cuda.current_device(),
)

self.use_cached_latents = self.cfg.use_cached_latents

@@ -246,7 +248,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):
# handle asynchronous grad reduction
no_sync_func = None
if not forward_only and self.with_distributed_adam:
no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,)
no_sync_func = partial(
self._optimizer.no_sync,
greedy_grad_copy=self.megatron_amp_O2,
)

# pipeline schedules will get these from self.model.config
for module in self.get_module_list():
@@ -291,12 +296,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):

def training_step(self, dataloader_iter):
"""
Our dataloaders produce a micro-batch and then we fetch
a number of microbatches depending on the global batch size and model parallel size
from the dataloader to produce a list of microbatches.
Batch should be a list of microbatches and those microbatches should on CPU.
Microbatches are then moved to GPU during the pipeline.
The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
Our dataloaders produce a micro-batch and then we fetch
a number of microbatches depending on the global batch size and model parallel size
from the dataloader to produce a list of microbatches.
Batch should be a list of microbatches and those microbatches should on CPU.
Microbatches are then moved to GPU during the pipeline.
The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
"""

# we zero grads here because we also call backward in the apex fwd/bwd functions
@@ -351,20 +356,20 @@ def validation_step(self, dataloader_iter):
return loss

def backward(self, *args, **kwargs):
""" LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from apex.
No need to call it here.
"""LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from apex.
No need to call it here.
"""
pass

def optimizer_zero_grad(self, *args, **kwargs):
""" LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""
pass

def _append_sequence_parallel_module_grads(self, module, grads):
""" Helper method for allreduce_sequence_parallel_gradients"""
"""Helper method for allreduce_sequence_parallel_gradients"""

for param in module.parameters():
sequence_parallel_param = getattr(param, 'sequence_parallel', False)
@@ -381,7 +386,8 @@ def process_batch(batch):
prompts, images = batch
# DB has more dedicated structure for encoding, so we enable autocasting here as well
with torch.cuda.amp.autocast(
self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype,
self.autocast_dtype in (torch.half, torch.bfloat16),
dtype=self.autocast_dtype,
):
images = images.cuda(non_blocking=True)

@@ -412,7 +418,7 @@ def fwd_output_only_func(batch, model):
return fwd_output_only_func

def setup(self, stage=None):
""" PTL hook that is executed after DDP spawns.
"""PTL hook that is executed after DDP spawns.
We setup datasets here as megatron datasets require DDP to instantiate.
See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information.
Args:
@@ -472,9 +478,9 @@ def setup_training_data(self, cfg):
center_crop=cfg.center_crop,
load_cache_latents=self.model.use_cached_latents,
cached_instance_data_root=self.cfg.data.get("cached_instance_dir", None),
cached_reg_data_root=self.cfg.data.get("cached_reg_dir", None)
if self.cfg.with_prior_preservation
else None,
cached_reg_data_root=(
self.cfg.data.get("cached_reg_dir", None) if self.cfg.with_prior_preservation else None
),
vae=self.model.vae,
text_encoder=self.model.text_encoder,
)
@@ -505,16 +511,16 @@ def setup_test_data(self, cfg):
pass

def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
""" PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
When using pipeline parallelism, we need the global batch to remain on the CPU,
since the memory overhead will be too high when using a large number of microbatches.
Microbatches are transferred from CPU to GPU inside the pipeline.
"""PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
When using pipeline parallelism, we need the global batch to remain on the CPU,
since the memory overhead will be too high when using a large number of microbatches.
Microbatches are transferred from CPU to GPU inside the pipeline.
"""
return batch

def _validate_trainer(self):
""" Certain trainer configurations can break training.
Here we try to catch them and raise an error.
"""Certain trainer configurations can break training.
Here we try to catch them and raise an error.
"""
if self.trainer.accumulate_grad_batches > 1:
raise ValueError(