Merge pull request #336 from blisc/u_switch_apex_ddp_to_torch
Switch Apex with Pytorch
okuchaiev authored Feb 14, 2020
2 parents f072029 + 8c26247 commit 3e04e09
Showing 15 changed files with 523 additions and 295 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -94,6 +94,8 @@ To release a new version, please update the changelog as follows:
- Updated licenses
- Updated NeMo's use of the logging library. `from nemo import logging` is now the recommended way of using the NeMo logger. `neural_factory.logger` and all other instances of logger are now deprecated and planned for removal in the next version. Please see PR 267 for complete change information.
([PR #267](https://github.com/NVIDIA/NeMo/pull/267), [PR #283](https://github.com/NVIDIA/NeMo/pull/283), [PR #305](https://github.com/NVIDIA/NeMo/pull/305), [PR #311](https://github.com/NVIDIA/NeMo/pull/311)) - @blisc
- Changed Distributed Data Parallel from Apex to Torch
([PR #336](https://github.com/NVIDIA/NeMo/pull/336)) - @blisc

- Added TRADE (dialogue state tracking model) on MultiWOZ dataset
([PR #322](https://github.com/NVIDIA/NeMo/pull/322)) - @chiphuyen, @VahidooX
@@ -108,6 +110,8 @@ To release a new version, please update the changelog as follows:
([PR #308](https://github.com/NVIDIA/NeMo/pull/309)) - @tkornuta-nvidia

### Removed
- gradient_predivide_factor arg of train() now has no effect
([PR #336](https://github.com/NVIDIA/NeMo/pull/336)) - @blisc
- Dropped support of the following ASR configs: jasper10x4.yaml, quartznet10x5.yaml, quartznet15x5_in.yaml, quartznet5x3.yaml, quartznet5x5.yaml, quartznet_an4.yaml. They are moved to experimental/configs and can still be used with v0.9 for use in replicating paper results
([PR #354](https://github.com/NVIDIA/NeMo/pull/354)) - @blisc

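The logging entry above is easier to follow with a concrete usage line. A minimal sketch of the recommended pattern, assuming only that nemo.logging exposes the usual level methods (info, warning) that actions.py later in this diff already calls:

    # Minimal sketch of the recommended logging usage from the changelog entry above.
    from nemo import logging

    logging.info("Starting training")
    logging.warning("WARNING: Loss is NaN or inf")  # same message used in actions.py below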
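For the "Changed Distributed Data Parallel from Apex to Torch" entry, the core of the switch is the wrapping logic in nemo/backends/pytorch/actions.py further down. A condensed sketch of that logic, where wrap_module, module, local_rank, and synced_batchnorm are illustrative names rather than NeMo API:

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_module(module: nn.Module, local_rank: int, synced_batchnorm: bool = False) -> nn.Module:
        # Per the PyTorch docs, convert BatchNorm layers to SyncBatchNorm before wrapping in DDP.
        if synced_batchnorm:
            module = nn.SyncBatchNorm.convert_sync_batchnorm(module)
        # broadcast_buffers=False disables buffer (batch-norm statistics) synchronization on the
        # forward pass; find_unused_parameters=True matches the call in actions.py.
        return DDP(module, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=True)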
31 changes: 28 additions & 3 deletions examples/nlp/asr_postprocessor/asr_postprocessor.py
@@ -26,6 +26,7 @@
eval_epochs_done_callback_wer,
eval_iter_callback,
)
from nemo.core import WeightShareTransform
from nemo.core.callbacks import CheckpointCallback
from nemo.utils.lr_policies import SquareAnnealing

@@ -126,9 +127,33 @@
)

# tie all embeddings weights
t_log_softmax.mlp.layer0.weight = encoder.bert.embeddings.word_embeddings.weight
decoder.embedding_layer.token_embedding.weight = encoder.bert.embeddings.word_embeddings.weight
decoder.embedding_layer.position_embedding.weight = encoder.bert.embeddings.position_embeddings.weight
# t_log_softmax.mlp.layer0.weight = encoder.bert.embeddings.word_embeddings.weight
# decoder.embedding_layer.token_embedding.weight = encoder.bert.embeddings.word_embeddings.weight
# decoder.embedding_layer.position_embedding.weight = encoder.bert.embeddings.position_embeddings.weight
t_log_softmax.tie_weights_with(
encoder,
weight_names=["mlp.layer0.weight"],
name2name_and_transform={
"mlp.layer0.weight": ("bert.embeddings.word_embeddings.weight", WeightShareTransform.SAME)
},
)
decoder.tie_weights_with(
encoder,
weight_names=["embedding_layer.token_embedding.weight"],
name2name_and_transform={
"embedding_layer.token_embedding.weight": ("bert.embeddings.word_embeddings.weight", WeightShareTransform.SAME)
},
)
decoder.tie_weights_with(
encoder,
weight_names=["embedding_layer.position_embedding.weight"],
name2name_and_transform={
"embedding_layer.position_embedding.weight": (
"bert.embeddings.position_embeddings.weight",
WeightShareTransform.SAME,
)
},
)


def create_pipeline(dataset, tokens_in_batch, clean=False, training=True):
10 changes: 8 additions & 2 deletions examples/nlp/language_modeling/bert_pretraining.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""
To pretrain BERT on raw text dataset run
@@ -224,7 +223,14 @@
# tie weights of MLM softmax layer and embedding layer of the encoder
if mlm_classifier.mlp.last_linear_layer.weight.shape != bert_model.bert.embeddings.word_embeddings.weight.shape:
raise ValueError("Final classification layer does not match embedding " "layer.")
mlm_classifier.mlp.last_linear_layer.weight = bert_model.bert.embeddings.word_embeddings.weight
# mlm_classifier.mlp.last_linear_layer.weight = bert_model.bert.embeddings.word_embeddings.weight
mlm_classifier.tie_weights_with(
bert_model,
weight_names=["mlp.last_linear_layer.weight"],
name2name_and_transform={
"mlp.last_linear_layer.weight": ("bert.embeddings.word_embeddings.weight", nemo_core.WeightShareTransform.SAME)
},
)


def create_pipeline(data_file, batch_size, preprocessed_data=False, batches_per_step=1, **kwargs):
11 changes: 9 additions & 2 deletions examples/nlp/language_modeling/language_modeling_transformer.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import math

import nemo
@@ -22,6 +21,7 @@
import nemo.collections.nlp.nm.trainables.common.token_classification_nm
from nemo.collections.nlp.callbacks.lm_transformer_callback import eval_epochs_done_callback, eval_iter_callback
from nemo.collections.nlp.data.datasets.lm_transformer_dataset import LanguageModelDataDesc
from nemo.core import WeightShareTransform
from nemo.utils.lr_policies import CosineAnnealing

parser = nemo.utils.NemoArgParser(description='LM Transformer')
@@ -114,7 +114,14 @@
)

# tie weight of embedding and log_softmax layers
log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight
# log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight
log_softmax.tie_weights_with(
encoder,
weight_names=["mlp.layer0.weight"],
name2name_and_transform={
"mlp.layer0.weight": ("embedding_layer.token_embedding.weight", WeightShareTransform.SAME)
},
)


def create_pipeline(
@@ -24,6 +24,7 @@
import nemo
import nemo.collections.nlp as nemo_nlp
from nemo.collections.nlp.callbacks.machine_translation_callback import eval_epochs_done_callback, eval_iter_callback
from nemo.core import WeightShareTransform
from nemo.utils.lr_policies import get_lr_policy

parser = nemo.utils.NemoArgParser(description='Transformer for Neural Machine Translation')
@@ -165,8 +166,25 @@
)

if tie_weight:
log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight
decoder.embedding_layer.token_embedding.weight = encoder.embedding_layer.token_embedding.weight
# log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight
log_softmax.tie_weights_with(
encoder,
weight_names=["mlp.last_linear_layer.weight"],
name2name_and_transform={
"mlp.last_linear_layer.weight": ("embedding_layer.token_embedding.weight", WeightShareTransform.SAME)
},
)
# decoder.embedding_layer.token_embedding.weight = encoder.embedding_layer.token_embedding.weight
decoder.tie_weights_with(
encoder,
weight_names=["embedding_layer.token_embedding.weight"],
name2name_and_transform={
"embedding_layer.token_embedding.weight": (
"embedding_layer.token_embedding.weight",
WeightShareTransform.SAME,
)
},
)


def create_pipeline(dataset_src, dataset_tgt, tokens_in_batch, clean=False, training=True):
128 changes: 86 additions & 42 deletions nemo/backends/pytorch/actions.py
@@ -5,13 +5,15 @@
import json
import os
from collections import defaultdict
from contextlib import ExitStack
from pathlib import Path
from typing import Dict, List, Optional
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

from nemo import logging
from nemo.backends.pytorch.module_wrapper import TrainableNeuralModuleWrapper
@@ -25,9 +27,8 @@

# these imports will happen on as-needed basis
amp = None
convert_syncbn = None
create_syncbn_process_group = None
DDP = None
# convert_syncbn = None
# create_syncbn_process_group = None
LARC = None
FusedLAMB = None
FusedAdam = None
@@ -59,18 +60,16 @@ def __init__(
global amp
amp = importlib.import_module('apex.amp')
if local_rank is not None:
global convert_syncbn
global create_syncbn_process_group
global DDP
# global convert_syncbn
# global create_syncbn_process_group
global LARC
global FusedLAMB
global FusedAdam
global FusedNovoGrad
parallel = importlib.import_module('apex.parallel')
apex_optimizer = importlib.import_module('apex.optimizers')
convert_syncbn = parallel.convert_syncbn_model
create_syncbn_process_group = parallel.create_syncbn_process_group
DDP = parallel.DistributedDataParallel
# convert_syncbn = parallel.convert_syncbn_model
# create_syncbn_process_group = parallel.create_syncbn_process_group
LARC = parallel.LARC
FusedLAMB = apex_optimizer.FusedLAMB
FusedAdam = apex_optimizer.FusedAdam
@@ -379,7 +378,7 @@ def __initialize_amp(
return optimizer

def __nm_graph_forward_pass(
self, call_chain, registered_tensors, mode=ModelMode.train, disable_allreduce=False, use_cache=False,
self, call_chain, registered_tensors, mode=ModelMode.train, use_cache=False,
):
for ind in range(1, len(call_chain)):
if use_cache:
@@ -399,12 +398,12 @@ def __nm_graph_forward_pass(
m_id = call_chain[ind][0].unique_instance_id
pmodule = self.module_reference_table[m_id][1]

if self._local_rank is not None:
if isinstance(pmodule, DDP):
if disable_allreduce:
pmodule.disable_allreduce()
else:
pmodule.enable_allreduce()
# if self._local_rank is not None:
# if isinstance(pmodule, DDP):
# if disable_allreduce:
# pmodule.disable_allreduce()
# else:
# pmodule.enable_allreduce()

if mode == ModelMode.train:
# if module.is_trainable():
@@ -935,9 +934,8 @@ def __extract_dynamic_axes(port_name: str, ntype: NeuralType, dynamic_axes: defa
outputs_to_drop = set()
if type(module).__name__ == "JasperEncoder":
logging.info(
f"Module is JasperEncoder. We are removing"
f"input and output length ports since they "
f"are not needed for deployment"
"Module is JasperEncoder. We are removing input and output length ports since they are not needed for "
"deployment"
)
inputs_to_drop.add("length")
outputs_to_drop.add("encoded_lengths")
@@ -1072,6 +1070,11 @@ def train(
gradient_predivide=False,
amp_max_loss_scale=2.0 ** 24,
):
if gradient_predivide:
logging.error(
"gradient_predivide is currently disabled, and is under consideration for removal in future versions. "
"If this functionality is needed, please raise a github issue."
)
if not optimization_params:
optimization_params = {}
num_epochs = optimization_params.get("num_epochs", None)
@@ -1213,23 +1216,44 @@ def train(
key = call_chain[i][0].unique_instance_id
pmodule = self.module_reference_table[key][1]
if not isinstance(pmodule, DDP) and isinstance(pmodule, torch.nn.Module):
gpf = 1
if gradient_predivide:
gpf = dist.get_world_size()
pmodule = DDP(pmodule, gradient_predivide_factor=gpf)

# Convert batchnorm modules to synced if applicable
if synced_batchnorm and isinstance(pmodule, torch.nn.Module):
world_size = dist.get_world_size()
if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0:
raise ValueError(
f"Synchronized batch norm group size"
f" ({synced_batchnorm_groupsize}) must be 0"
f" or divide total number of GPUs"
f" ({world_size})."
# gpf = 1
# if gradient_predivide:
# gpf = dist.get_world_size()
# pmodule = DDP(pmodule, gradient_predivide_factor=gpf) # Old Apex Method

# Per pytorch docs, convert sync bn prior to DDP
if synced_batchnorm:
world_size = dist.get_world_size()
sync_batchnorm_group = None
if synced_batchnorm_groupsize > 0:
if world_size % synced_batchnorm_groupsize != 0:
raise ValueError(
f"Synchronized batch norm group size ({synced_batchnorm_groupsize}) must be 0"
f" or divide total number of GPUs ({world_size})."
)
sync_batchnorm_group = torch.distributed.new_group(synced_batchnorm_groupsize)
pmodule = nn.SyncBatchNorm.convert_sync_batchnorm(
pmodule, process_group=sync_batchnorm_group
)
process_group = create_syncbn_process_group(synced_batchnorm_groupsize)
pmodule = convert_syncbn(pmodule, process_group=process_group)

# By default, disable broadcast_buffers. This disables batch norm synchronization on forward
# pass
pmodule = DDP(
pmodule, device_ids=[self.local_rank], broadcast_buffers=False, find_unused_parameters=True
)

# # Convert batchnorm modules to synced if applicable
# if synced_batchnorm and isinstance(pmodule, torch.nn.Module):
# world_size = dist.get_world_size()
# if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0:
# raise ValueError(
# f"Synchronized batch norm group size"
# f" ({synced_batchnorm_groupsize}) must be 0"
# f" or divide total number of GPUs"
# f" ({world_size})."
# )
# process_group = create_syncbn_process_group(synced_batchnorm_groupsize)
# pmodule = convert_syncbn(pmodule, process_group=process_group)

self.module_reference_table[key] = (
self.module_reference_table[key][0],
@@ -1308,9 +1332,7 @@ def train(
}
disable_allreduce = batch_counter < (batches_per_step - 1)
self.__nm_graph_forward_pass(
call_chain=curr_call_chain,
registered_tensors=registered_tensors,
disable_allreduce=disable_allreduce,
call_chain=curr_call_chain, registered_tensors=registered_tensors,
)

curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1]
@@ -1331,19 +1353,31 @@
if nan:
continue
if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce,) as scaled_loss:
with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss:
if torch.isnan(scaled_loss).any() or torch.isinf(scaled_loss).any():
if stop_on_nan_loss:
raise ValueError('Loss is NaN or inf -' ' exiting')
logging.warning('WARNING: Loss is NaN or inf')
curr_optimizer.zero_grad()
continue
scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
if disable_allreduce:
with ExitStack() as stack:
for mod in self.get_DDP_modules(curr_call_chain):
stack.enter_context(mod.no_sync())
scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
else:
scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
# no AMP optimizations needed
else:
# multi-GPU, float32
if self._local_rank is not None:
final_loss.backward(bps_scale.to(final_loss.get_device()))
if disable_allreduce:
with ExitStack() as stack:
for mod in self.get_DDP_modules(curr_call_chain):
stack.enter_context(mod.no_sync())
final_loss.backward(bps_scale.to(final_loss.get_device()))
else:
final_loss.backward(bps_scale.to(final_loss.get_device()))
# single device (CPU or GPU)
else:
# Fix (workaround?) enabling backpropagation of gradients on CPUs.
@@ -1438,3 +1472,13 @@ def infer(
use_cache=use_cache,
offload_to_cpu=offload_to_cpu,
)

def get_DDP_modules(self, call_chain):
modules = []
for ind in range(1, len(call_chain)):
m_id = call_chain[ind][0].unique_instance_id
module = self.module_reference_table[m_id][1]
if isinstance(module, DDP):
modules.append(module)

return modules
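
The ExitStack/no_sync pattern above replaces Apex's disable_allreduce/enable_allreduce toggling when gradients are accumulated over batches_per_step. A self-contained sketch of the idea, with ddp_modules, forward_fn, optimizer, and micro_batches as illustrative names; in NeMo the wrapped modules are collected by get_DDP_modules(call_chain):

    from contextlib import ExitStack

    def accumulation_step(ddp_modules, forward_fn, optimizer, micro_batches):
        # ddp_modules: the DistributedDataParallel-wrapped modules exercised by forward_fn.
        optimizer.zero_grad()
        for i, batch in enumerate(micro_batches):
            if i < len(micro_batches) - 1:
                # Intermediate micro-batches: per the PyTorch docs, enter no_sync() before the
                # forward pass so gradients accumulate locally without an all-reduce.
                with ExitStack() as stack:
                    for mod in ddp_modules:
                        stack.enter_context(mod.no_sync())
                    loss = forward_fn(batch)
                    loss.backward()
            else:
                # Final micro-batch: a normal forward/backward triggers the single all-reduce
                # over the accumulated gradients.
                loss = forward_fn(batch)
                loss.backward()
        optimizer.step()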