From 5d5919fe8a4877d260a22cd80690b7ef3acde7a0 Mon Sep 17 00:00:00 2001 From: Adi Renduchintala Date: Wed, 1 May 2024 10:50:52 -0700 Subject: [PATCH 01/73] unfused lora (#9004) * WIP unfused lora Signed-off-by: arendu * unfused lora training and generation Signed-off-by: arendu * update Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: arendu * GQA support for unfused lora Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * converter for fused to unfused lora added Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * defaults Signed-off-by: arendu * refac Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleaned Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unfusing h to 4h adapter Signed-off-by: arendu * unfused hto 4h Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix for canonical Signed-off-by: arendu * updates Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: arendu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../conf/megatron_gpt_finetuning_config.yaml | 1 + .../conf/megatron_gpt_generate_config.yaml | 1 + .../common/megatron/adapters/mcore_mixins.py | 42 +++- .../megatron/adapters/parallel_adapters.py | 179 +++++++++++++++ nemo/collections/nlp/parts/peft_config.py | 50 ++++- .../convert_nemo_to_canonical.py | 212 ++++++++++++++++++ 6 files changed, 469 insertions(+), 16 deletions(-) create mode 100644 scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml index 40347f317fbb..6517b62010b4 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml @@ -101,6 +101,7 @@ model: position_embedding_strategy: null # used only when weight_tying is True lora_tuning: + variant: "nemo" # can be "nemo" or "canonical" target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) adapter_dim: 32 alpha: ${model.peft.lora_tuning.adapter_dim} diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml index 67d43eb303f4..592eed6c4420 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_generate_config.yaml @@ -89,6 +89,7 @@ model: position_embedding_strategy: null # used only when weight_tying is True lora_tuning: + variant: "nemo" # can be either "canonical" or "nemo" target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) adapter_dim: 32 adapter_dropout: 0.0 diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index a5e886f3b479..16ded8e2c682 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -37,6 +37,8 @@ LoraDenseAttentionAdapterConfig, LoraHto4HAdapterConfig, LoraKQVAdapterConfig, + LoraUnfusedHto4HAdapterConfig, + LoraUnfusedKQVAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, PromptEncoderAdapterConfig, @@ -67,7 +69,12 @@ def mcore_register_adapters(self): Setup NeMo LoRA or IA3 adapter to this MCore layer. """ self.set_accepted_adapter_types( - [LoraKQVAdapterConfig._target_, LoraDenseAttentionAdapterConfig._target_, InfusedAdapterConfig._target_] + [ + LoraUnfusedKQVAdapterConfig._target_, + LoraKQVAdapterConfig._target_, + LoraDenseAttentionAdapterConfig._target_, + InfusedAdapterConfig._target_, + ] ) self.linear_qkv.return_layernorm_output = True # need layernorm output for lora mlp if ( @@ -102,12 +109,20 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): # LoRA logic if self.is_adapter_available(): + lora_adapter = None lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER) + lora_unfused_kqv_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_KQV_ADAPTER) if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']: + lora_adapter = lora_kqv_adapter + if lora_unfused_kqv_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_KQV_ADAPTER]['enabled']: + assert lora_adapter is None, "Expected only one of lora_kqv_adapter or lora_unfused_kqv_adapter" + lora_adapter = lora_unfused_kqv_adapter + + if lora_adapter: if layernorm_output is not None: - lora_mixed_qkv = lora_kqv_adapter(layernorm_output) + lora_mixed_qkv = lora_adapter(layernorm_output) else: - lora_mixed_qkv = lora_kqv_adapter(hidden_states) + lora_mixed_qkv = lora_adapter(hidden_states) mixed_qkv = mixed_qkv + lora_mixed_qkv @@ -251,7 +266,12 @@ def mcore_register_adapters(self): Setup NeMo IA3 adapter to this MCore layer. """ self.set_accepted_adapter_types( - [LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_] + [ + LoraUnfusedHto4HAdapterConfig._target_, + LoraHto4HAdapterConfig._target_, + Lora4HtoHAdapterConfig._target_, + MLPInfusedAdapterConfig._target_, + ] ) # only self attn (packed qkv) for now self.linear_fc1.return_layernorm_output = True # need layernorm output for lora mlp if ( @@ -274,9 +294,17 @@ def forward(self, hidden_states): # LoRA logic if self.is_adapter_available(): - lora_linear_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER) - if lora_linear_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']: - lora_output = lora_linear_fc1_adapter(layernorm_output) + lora_adapter = None + lora_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER) + lora_unfused_fc1_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_Hto4H_ADAPTER) + if lora_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']: + lora_adapter = lora_fc1_adapter + if lora_unfused_fc1_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_Hto4H_ADAPTER]['enabled']: + assert lora_adapter is None, "Expected only one of LORA_Hto4H_ADAPTER or LORA_UNFUSED_Hto4H_ADAPTER" + lora_adapter = lora_unfused_fc1_adapter + + if lora_adapter: + lora_output = lora_adapter(layernorm_output) intermediate_parallel = intermediate_parallel + lora_output if self.config.bias_activation_fusion: diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 5037bb1b3634..2a5372d11ab5 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -75,11 +75,13 @@ class AdapterName(str, enum.Enum): POST_ATTN_ADAPTER = 'adapter_2' PTUNING_ADAPTER = "ptuning_adapter" LORA_KQV_ADAPTER = "lora_kqv_adapter" + LORA_UNFUSED_KQV_ADAPTER = "lora_unfused_kqv_adapter" LORA_KV_ADAPTER = "lora_kv_adapter" LORA_Q_ADAPTER = "lora_q_adapter" MM_LINEAR_ADAPTER = "mm_linear_adapter" LORA_DENSE_ATTENTION_ADAPTER = "lora_dense_attention_adapter" LORA_Hto4H_ADAPTER = "lora_hto4h_adapter" + LORA_UNFUSED_Hto4H_ADAPTER = "lora_unfused_hto4h_adapter" LORA_4HtoH_ADAPTER = "lora_4htoh_adapter" MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter" PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter" @@ -457,6 +459,183 @@ class Lora4HtoHAdapterConfig(ParallelLinearAdapterConfig): input_is_parallel: bool = True +class LoraUnfusedHto4HAdapter(nn.Module, AdapterModuleUtil): + def __init__( + self, + in_features: int, + out_features: int, + dim: int, + activation: str = 'swish', + norm_position: Optional[str] = 'post', + norm_type: Optional[str] = 'mixedfusedlayernorm', + column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise. + row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise. + gather_output: bool = True, + input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers + dropout: float = 0.0, + model_parallel_config: Optional[ModelParallelConfig] = None, + alpha: float | None = None, + dropout_position: str = 'post', + a2a_experimental: bool = False, # TODO: should rename this or make it a default feature + **kwargs, + ): + super().__init__() + self.gate_adapter = ParallelLinearAdapter( + in_features, + out_features // 2, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + input_is_parallel, + dropout, + model_parallel_config, + alpha, + dropout_position, + a2a_experimental, + ) + self.up_adapter = ParallelLinearAdapter( + in_features, + out_features // 2, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + input_is_parallel, + dropout, + model_parallel_config, + alpha, + dropout_position, + a2a_experimental, + ) + + def forward(self, x): + gate_x = self.gate_adapter(x) + up_x = self.up_adapter(x) + x = torch.concat([gate_x, up_x], dim=2) + return x + + +@dataclass +class LoraUnfusedHto4HAdapterConfig(ParallelLinearAdapterConfig): + _target_: str = "{0}.{1}".format(LoraUnfusedHto4HAdapter.__module__, LoraUnfusedHto4HAdapter.__name__) + + +class LoraUnfusedKQVAdapter(nn.Module, AdapterModuleUtil): + def __init__( + self, + in_features: int, + dim: int, + activation: str = 'swish', + norm_position: Optional[str] = 'post', + norm_type: Optional[str] = 'mixedfusedlayernorm', + column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise. + row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise. + gather_output: bool = True, + input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers + dropout: float = 0.0, + model_parallel_config: Optional[ModelParallelConfig] = None, + alpha: float | None = None, + dropout_position: str = 'post', + a2a_experimental: bool = False, # TODO: should rename this or make it a default feature + num_query_groups: Optional[int] = None, + kv_channels: Optional[int] = None, + **kwargs, + ): + super().__init__() + if num_query_groups is not None and kv_channels is not None: + out_features = kv_channels * num_query_groups + else: + out_features = in_features + + self.q_adapter = ParallelLinearAdapter( + in_features, + in_features, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + input_is_parallel, + dropout, + model_parallel_config, + alpha, + dropout_position, + a2a_experimental, + ) + + self.k_adapter = ParallelLinearAdapter( + in_features, + out_features, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + input_is_parallel, + dropout, + model_parallel_config, + alpha, + dropout_position, + a2a_experimental, + ) + self.v_adapter = ParallelLinearAdapter( + in_features, + out_features, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + input_is_parallel, + dropout, + model_parallel_config, + alpha, + dropout_position, + a2a_experimental, + ) + + def forward(self, x): + qx = self.q_adapter(x) + kx = self.k_adapter(x) + vx = self.v_adapter(x) + x = torch.concat([qx, kx, vx], dim=2) + return x + + +@dataclass +class LoraUnfusedKQVAdapterConfig(AdapterConfig): + in_features: int + dim: int + activation: str = 'swish' + norm_position: Optional[str] = 'post' + norm_type: Optional[str] = 'mixedfusedlayernorm' + column_init_method: str = 'xavier' + row_init_method: str = 'zero' + gather_output: bool = True + input_is_parallel: bool = False + dropout: float = 0.0 + dropout_position: str = 'post' + alpha: float | None = None + network_alpha: int | None = None + a2a_experimental: bool = False + num_query_groups: Optional[int] = None + kv_channels: Optional[int] = None + _target_: str = "{0}.{1}".format(LoraUnfusedKQVAdapter.__module__, LoraUnfusedKQVAdapter.__name__) + + class PromptEncoderAdapter(nn.Module, AdapterModuleUtil): """ The Tensor Parallel MLP prompt encoder network that is used to generate the virtual diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index 63caa409b218..47d5167d630e 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -36,6 +36,8 @@ LoraHto4HAdapterConfig, LoraKQVAdapterConfig, LoraKQVAdapterWeightTyingConfig, + LoraUnfusedHto4HAdapterConfig, + LoraUnfusedKQVAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, ParallelLinearAdapterWeightTyingConfig, @@ -132,11 +134,26 @@ def __init__(self, cfg): for module in target_modules: if module == PEFT_MODULE_MAP["qkv_module"]: - adapter_cfg = self._create_lora_config( - cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, LoraKQVAdapterConfig - ) - name_key_to_cfg[AdapterName.LORA_KQV_ADAPTER] = adapter_cfg - name_key_to_mcore_mixins[AdapterName.LORA_KQV_ADAPTER] = [("self_attention", MCoreSelfAttentionMixin)] + if lora_cfg.get("variant", "nemo") == "canonical": + _adapter_name = AdapterName.LORA_UNFUSED_KQV_ADAPTER + _adapter_cfg_cls = LoraUnfusedKQVAdapterConfig + adapter_cfg = self._create_lora_config( + cfg, + lora_cfg, + cfg.hidden_size, + qkv_projection_size, + _adapter_cfg_cls, + num_query_groups=num_query_groups, + kv_channels=kv_channels, + ) + else: + _adapter_name = AdapterName.LORA_KQV_ADAPTER + _adapter_cfg_cls = LoraKQVAdapterConfig + adapter_cfg = self._create_lora_config( + cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, _adapter_cfg_cls + ) + name_key_to_cfg[_adapter_name] = adapter_cfg + name_key_to_mcore_mixins[_adapter_name] = [("self_attention", MCoreSelfAttentionMixin)] elif module == PEFT_MODULE_MAP["dense_module"]: adapter_cfg = self._create_lora_config( @@ -149,11 +166,18 @@ def __init__(self, cfg): elif module == PEFT_MODULE_MAP["hto4h_module"]: hto4h_projection_size = cfg.ffn_hidden_size * 2 if fast_glu_activation else cfg.ffn_hidden_size + if lora_cfg.get("variant", "nemo") == "canonical": + _adapter_name = AdapterName.LORA_UNFUSED_Hto4H_ADAPTER + _adapter_cfg_cls = LoraUnfusedHto4HAdapterConfig + else: + _adapter_name = AdapterName.LORA_Hto4H_ADAPTER + _adapter_cfg_cls = LoraHto4HAdapterConfig + adapter_cfg = self._create_lora_config( - cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, LoraHto4HAdapterConfig + cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, _adapter_cfg_cls ) - name_key_to_cfg[AdapterName.LORA_Hto4H_ADAPTER] = adapter_cfg - name_key_to_mcore_mixins[AdapterName.LORA_Hto4H_ADAPTER] = [("mlp", MCoreMLPMixin)] + name_key_to_cfg[_adapter_name] = adapter_cfg + name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)] elif module == PEFT_MODULE_MAP["4htoh_module"]: adapter_cfg = self._create_lora_config( cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, Lora4HtoHAdapterConfig @@ -170,7 +194,9 @@ def __init__(self, cfg): self.name_key_to_mcore_mixins = name_key_to_mcore_mixins super().__init__(lora_cfg, name_key_to_cfg) - def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls): + def _create_lora_config( + self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls, num_query_groups=None, kv_channels=None + ): config_args = { "in_features": in_features, "out_features": out_features, @@ -187,6 +213,12 @@ def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_ "a2a_experimental": lora_cfg.get("a2a_experimental", False), } + if adapter_cfg_cls == LoraUnfusedKQVAdapterConfig: + assert num_query_groups is not None, "num_query_groups must be provided for canonical Lora" + assert kv_channels is not None, "kv_channels must be provided for canonical Lora" + config_args.update({"num_query_groups": num_query_groups, "kv_channels": kv_channels}) + config_args.pop("out_features") + if lora_cfg.weight_tying: position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None) if position_embedding_strategy is None: diff --git a/scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py b/scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py new file mode 100644 index 000000000000..f2974aca1642 --- /dev/null +++ b/scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py @@ -0,0 +1,212 @@ +#!/usr/bin/env +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Convert nemo style (fused) lora checkpoint to canonical (unfused) lora checkpoint. +Currently supports TP=PP=1 only. + +Example usage: +python scripts/checkpoint_converters/lora_converters/convert_nemo_to_canonical.py \ + --lora_path nemo_style_lora_model.nemo \ + --output_path ./canonical_style_lora_model.nemo + +""" +import tempfile +from argparse import ArgumentParser +from typing import Dict + +import torch +from omegaconf import OmegaConf, open_dict +from scripts.nlp_language_modeling.merge_lora_weights.merge import replace_number_add_offset + +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector + + +def rename_keys(key): + new_keys = [] + if "lora_kqv_adapter" in key: + new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.q_adapter.")) + new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.k_adapter.")) + new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.v_adapter.")) + elif "lora_hto4h_adapter" in key: + new_keys.append(key.replace(".lora_hto4h_adapter.", ".lora_unfused_hto4h_adapter.gate_adapter.")) + new_keys.append(key.replace(".lora_hto4h_adapter.", ".lora_unfused_hto4h_adapter.up_adapter.")) + return new_keys + + +def reformat_module_names_to_hf(tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + new_tensors = dict() + for module_name, module_weight in tensors.items(): + # map linear_in and linear_out to lora_a/lora_b counterparts + new_module_name = "base_model." + module_name.replace("linear_in", "lora_A").replace("linear_out", "lora_B") + + # map target modules to their vLLM/HF counterparts + new_module_name = new_module_name.replace("q_adapter", "q_proj") + new_module_name = new_module_name.replace("k_adapter", "k_proj") + new_module_name = new_module_name.replace("v_adapter", "v_proj") + new_module_name = new_module_name.replace("lora_dense_attention_adapter", "o_proj") + new_module_name = new_module_name.replace("lora_4htoh_adapter", "down_proj") + new_module_name = new_module_name.replace("gate_adapter", "gate_proj") + new_module_name = new_module_name.replace("up_adapter", "up_proj") + + # map other parts of the module names to fit vLLM/huggingface + new_module_name = new_module_name.replace(".adapter_layer", "") + new_module_name = new_module_name.replace(".lora_unfused_kqv_proj", "") + new_module_name = new_module_name.replace(".lora_unfused_hto4h_adapter", "") + new_module_name = new_module_name.replace("self_attention", "self_attn") + new_module_name = new_module_name.replace("decoder", "model") + + new_tensors[new_module_name] = module_weight + return new_tensors + + +def convert_hto4h(lora_weights, lora_config): + assert len(lora_weights) == 1, "Only single TP supported for now" + keys_to_update = [] + for key in lora_weights[0].keys(): + if "lora_hto4h_adapter" in key: + keys_to_update.append(key) + + for key in keys_to_update: + if "linear_in" in key: + for new_key in rename_keys(key): + lora_weights[0][new_key] = lora_weights[0][key] + print(new_key, lora_weights[0][new_key].shape) + elif "linear_out" in key: + for idx, new_key in enumerate(rename_keys(key)): + orginal_shape = lora_weights[0][key].shape[0] + lora_weights[0][new_key] = lora_weights[0][key][ + idx * (orginal_shape // 2) : (idx + 1) * (orginal_shape // 2) + ] + print(new_key, lora_weights[0][new_key].shape) + + lora_weights[0].pop(key) + return lora_weights + + +def convert_qkv(lora_weights, lora_model_cfg): + assert len(lora_weights) == 1, "Only single TP supported for now" + if ( + lora_model_cfg.get("num_query_groups", lora_model_cfg.num_attention_heads) + != lora_model_cfg.num_attention_heads + ): + kv_channels = int(lora_model_cfg.hidden_size / lora_model_cfg.num_attention_heads) + kv_size = int(lora_model_cfg.num_query_groups * kv_channels) + else: + kv_size = int(lora_model_cfg.hidden_size) + q_size = lora_model_cfg.hidden_size + k_size, v_size = kv_size, kv_size + + keys_to_update = [] + for key in lora_weights[0].keys(): + if "lora_kqv_adapter" in key: + keys_to_update.append(key) + + for key in keys_to_update: + if "linear_in" in key: + for new_key in rename_keys(key): + lora_weights[0][new_key] = lora_weights[0][key] + print(new_key, lora_weights[0][new_key].shape) + elif "linear_out" in key: + srt = 0 + for new_key, size in zip(rename_keys(key), [q_size, k_size, v_size]): + lora_weights[0][new_key] = lora_weights[0][key][srt : srt + size] + print(new_key, lora_weights[0][new_key].shape) + srt = srt + size + + lora_weights[0].pop(key) + return lora_weights + + +def convert_lora(lora_nemo, save_path, hf_format=False): + with tempfile.TemporaryDirectory() as tmpdir: + NLPSaveRestoreConnector._unpack_nemo_file(lora_nemo, tmpdir) + config_file = f"{tmpdir}/model_config.yaml" + lora_config = OmegaConf.load(config_file) + tp_size = lora_config.tensor_model_parallel_size + pp_size = lora_config.pipeline_model_parallel_size + + lora_state_dict = [{}] * tp_size + + for pp in range(pp_size): + for tp in range(tp_size): + if tp_size == 1: + ckpt_file = f"{tmpdir}/model_weights.ckpt" + elif pp_size == 1: + ckpt_file = f"{tmpdir}/mp_rank_{tp:02d}/model_weights.ckpt" + else: + ckpt_file = f"{tmpdir}/tp_rank_{tp:02d}_pp_rank_{pp:03d}/model_weights.ckpt" + + l = torch.load(ckpt_file, map_location=torch.device('cpu')) + if pp == 0: + lora_state_dict[tp] = l + else: + # calculate layer offset + layer_offset = lora_config.num_layers // pp_size * pp + for key, value in l.items(): + new_key = replace_number_add_offset(key, layer_offset) + lora_state_dict[tp][new_key] = value + + with open_dict(lora_config): + lora_config.peft.lora_tuning.variant = "canonical" + with open(f"{tmpdir}/model_config.yaml", "w") as f: + OmegaConf.save(lora_config, f) + lora_state_dict = convert_qkv(lora_state_dict, lora_config) + lora_state_dict = convert_hto4h(lora_state_dict, lora_config) + # TODO: currently suport tp=1 + lora_state_dict = lora_state_dict[0] + if hf_format: + lora_state_dict = reformat_module_names_to_hf(lora_state_dict) + torch.save(lora_state_dict, f"{save_path}/model_weights_hf_formatted.pt") + else: + torch.save(lora_state_dict, f"{tmpdir}/model_weights.ckpt") + NLPSaveRestoreConnector._make_nemo_file_from_folder(save_path, tmpdir) + + return lora_state_dict, lora_config + + +def fix_for_O2(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if "model.module." not in k: + new_state_dict[k.replace('model.', 'model.module.')] = v + return new_state_dict + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--lora_path", + type=str, + default=None, + required=True, + help="Path to NeMo style (fused) lora checkpoint in .nemo file format", + ) + parser.add_argument( + "--output_path", + type=str, + default=None, + required=True, + help="Path to save the canonical (unfused) lora .nemo file.", + ) + parser.add_argument("--hf_format", action='store_true', help="saves tensors in huggingface naming format.") + parser.add_argument("--precision", type=str, default="16", help="Model precision") + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = get_args() + convert_lora(args.lora_path, args.output_path, args.hf_format) From e267406afe4369c67d29989b4fe7bd0c0a9a1f5e Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Wed, 1 May 2024 19:54:47 +0200 Subject: [PATCH 02/73] Restore PTQ tests for Llama2 (reopened) (#9064) * Restore PTQ tests for Llama2 (MR-9018) Signed-off-by: Jan Lasek * try not using release Signed-off-by: eharper * checkout v4 Signed-off-by: eharper --------- Signed-off-by: Jan Lasek Signed-off-by: eharper Co-authored-by: eharper --- .github/workflows/cicd-main.yml | 149 +++++++++++++++++++++++++++++- nemo/export/quantize/quantizer.py | 4 +- nemo/utils/model_utils.py | 20 +++- tests/setup/__main__.py | 4 +- 4 files changed, 166 insertions(+), 11 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index de250596da62..6f090bd34213 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -132,6 +132,9 @@ jobs: apt-get update && apt-get install libsox-fmt-all -y && \ popd + # AMMO installation + pip install nvidia-ammo~=0.9.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir + # PyTorch Lightning version python -c "import pytorch_lightning; print(pytorch_lightning.__version__)" @@ -220,7 +223,26 @@ jobs: - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" - + L0_Setup_Test_Data_And_Models: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python -m tests.setup --save_dir /home/TestData/nlp + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" ## - name: L2: Multimodal Imagen Train @@ -243,10 +265,9 @@ jobs: uses: actions/checkout@v4 - run: | CUDA_VISIBLE_DEVICES=0 python scripts/checkpoint_converters/convert_llama_hf_to_nemo.py \ - --input_name_or_path=/home/TestData/nlp/megatron_llama/llama-ci-hf \ - --output_path=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \ + --input_name_or_path=/home/TestData/nlp/megatron_llama/llama-ci-hf-tiny \ + --output_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ --precision=16 - rm -f /home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" @@ -322,6 +343,124 @@ jobs: - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" + L2_PTQ_Llama2_Export_Only: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_llama_quantization.py \ + model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + quantization.algorithm=null \ + model_save=/home/TestData/nlp/megatron_llama/ci_baseline + + rm -rf /home/TestData/nlp/megatron_llama/ci_baseline + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + + L2_PTQ_Llama2_FP8: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_llama_quantization.py \ + model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + tensor_model_parallel_size=2 \ + trainer.devices=2 \ + quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ + quantization.algorithm=fp8 \ + quantization.num_calib_size=8 \ + inference.batch_size=2 \ + export.inference_tensor_parallel=2 \ + model_save=/home/TestData/nlp/megatron_llama/ci_fp8.qnemo + + rm -rf /home/TestData/nlp/megatron_llama/ci_fp8.qnemo + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + + L2_PTQ_Llama2_INT8_SQ: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_llama_quantization.py \ + model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ + quantization.algorithm=int8_sq \ + quantization.num_calib_size=8 \ + inference.batch_size=2 \ + model_save=/home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo + + rm -rf /home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + + L2_PTQ_Llama2_INT4_AWQ: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_llama_quantization.py \ + model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + tensor_model_parallel_size=1 \ + trainer.devices=1 \ + quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ + quantization.algorithm=int4_awq \ + quantization.num_calib_size=8 \ + inference.batch_size=2 \ + model_save=/home/TestData/nlp/megatron_llama/ci_int4_awq.qnemo + + rm -rf /home/TestData/nlp/megatron_llama/ci_int4_awq.qnemo + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + # L2: ASR dev run ASR_dev_run_Speech_to_Text: needs: [cicd-test-container-setup] @@ -4664,7 +4803,7 @@ jobs: --volume /mnt/datadrive/TestData:/home/TestData steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - run: | rm -rf /home/TestData/nlp/megatron_ir/working_dir diff --git a/nemo/export/quantize/quantizer.py b/nemo/export/quantize/quantizer.py index 2663f8fe9bac..783f47a08e79 100644 --- a/nemo/export/quantize/quantizer.py +++ b/nemo/export/quantize/quantizer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import tarfile from contextlib import nullcontext from typing import List, Optional @@ -21,7 +20,6 @@ import torch.distributed as dist from megatron.core import parallel_state from megatron.core.transformer.module import Float16Module -from megatron.training.utils import unwrap_model from omegaconf import OmegaConf from omegaconf.omegaconf import DictConfig, open_dict from pytorch_lightning.trainer.trainer import Trainer @@ -31,7 +29,7 @@ from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import logging from nemo.utils.distributed import temporary_directory -from nemo.utils.model_utils import load_config, save_artifacts +from nemo.utils.model_utils import load_config, save_artifacts, unwrap_model try: import ammo.torch.quantization as atq diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index 95d1bc414625..f4eefd39a9ea 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -24,7 +24,7 @@ from enum import Enum from functools import lru_cache from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import wrapt @@ -92,6 +92,24 @@ def load_config(model_file: str) -> DictConfig: return model_config +def unwrap_model(model, module_instances: Union[Type, Tuple[Type]]): + """Unwrap model from wrapper classes like Float16Module, for example.""" + + # TODO: Import this from megatron.core once moved there from megatron.training. + return_list = True + if not isinstance(model, list): + model = [model] + return_list = False + unwrapped_model = [] + for model_module in model: + while isinstance(model_module, module_instances): + model_module = model_module.module + unwrapped_model.append(model_module) + if not return_list: + return unwrapped_model[0] + return unwrapped_model + + def param_is_not_shared(param): return not hasattr(param, 'shared') or not param.shared diff --git a/tests/setup/__main__.py b/tests/setup/__main__.py index 289a2537e2f2..a08ccdaa1634 100644 --- a/tests/setup/__main__.py +++ b/tests/setup/__main__.py @@ -34,8 +34,8 @@ ) create_hf_model( - model_name_or_path="/home/TestData/nlp/meta-llama/Llama-2-7b-hf", - output_dir=os.path.join(args.save_dir, "megatron_llama/llama-ci-hf"), + model_name_or_path="/home/TestData/nlp/megatron_llama/llama-ci-hf", + output_dir=os.path.join(args.save_dir, "megatron_llama/llama-ci-hf-tiny"), config_updates={"hidden_size": 256, "num_attention_heads": 4, "num_hidden_layers": 2, "num_key_value_heads": 4}, overwrite=args.overwrite, ) From 3d87ed7109a456b526f2fe19809623cd5183d5b3 Mon Sep 17 00:00:00 2001 From: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> Date: Wed, 1 May 2024 13:10:59 -0500 Subject: [PATCH 03/73] add clip H config (#9082) * add clip H config * add comment to 1st line of yaml --- .../clip/conf/megatron_clip_VIT-H-14.yaml | 204 ++++++++++++++++++ 1 file changed, 204 insertions(+) create mode 100644 examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-H-14.yaml diff --git a/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-H-14.yaml b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-H-14.yaml new file mode 100644 index 000000000000..b37d64a325e5 --- /dev/null +++ b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-H-14.yaml @@ -0,0 +1,204 @@ +# An example model that works with this config is "https://huggingface.co/yuvalkirstain/PickScore_v1" +model: + precision: 32 + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 32 # limited by GPU memory + global_batch_size: 32 # will use more micro batches to reach global batch size + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + restore_from_pretrained: null # used in fine-tuning + # multimodal configs + output_dim: 1024 + # As the number of devices used to train increases, so does the space complexity of + # the logit matrix. Using a naïve all-gather scheme, space complexity will be + # `O(n^2)`. Instead, complexity may become effectively linear if the flags + # `--gather-with-grad` and `--local-loss` are used. This alteration results in one-to-one + # numerical results as the naïve method. + local_loss: False # calculate loss w/ local features @ global (instead of realizing full global @ global matrix) + gather_with_grad: True # enable full distributed gradient for feature gather, set this to False may cause convergence issue + + vision: + precision: 32 + # vision configs + patch_dim: 14 + img_h: 224 + img_w: 224 + image_mean: null + image_std: null + num_channels: 3 + drop_patch_rate: 0.0 + drop_path_rate: 0.0 + global_average_pool: False + output_dim: ${model.output_dim} + class_token_length: 1 + preprocess_layernorm: True # apply layer norm to embedded tokens + + # model architecture + encoder_seq_length: 196 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: learned_parameters + num_layers: 32 + hidden_size: 1280 + ffn_hidden_size: 5120 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 16 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0. # Dropout probability for hidden state transformer. + attention_dropout: 0. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm # Type of normalization layers + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: False + bias_activation_fusion: False + megatron_legacy: True + activation: gelu + + + + text: + precision: 32 + # text configs + output_dim: ${model.output_dim} + + # model architecture + encoder_seq_length: 77 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: learned_parameters + num_layers: 24 + hidden_size: 1024 + ffn_hidden_size: 4096 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 16 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0. # Dropout probability for hidden state transformer. + attention_dropout: 0. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm # Type of normalization layers + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + num_micro_batches_with_partial_activation_checkpoints: null + activations_checkpoint_layers_per_pipeline: null + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: False + bias_activation_fusion: False + megatron_legacy: True + + transformer_engine: False + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + activation: gelu + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'huggingface' + type: 'openai/clip-vit-large-patch14' + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + + data: + num_workers: 8 + train: + dataset_path: # List of paths to pkl files or tar files + - /datasets/coyo/test.pkl + validation: # List of paths to pkl files or tar files + dataset_path: + - /datasets/coyo/test.pkl + webdataset: + infinite_sampler: False + local_root_path: /datasets/coyo + + imagenet_val: null # Path to imagenet val set for conducting zero shot evaluation. + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: fused_adam + lr: 1e-3 + weight_decay: 0.2 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 2000 + constant_steps: 0 + min_lr: 1e-5 \ No newline at end of file From f658b6f0445403c338c7371941b1fe644832df48 Mon Sep 17 00:00:00 2001 From: anteju <108555623+anteju@users.noreply.github.com> Date: Wed, 1 May 2024 12:59:24 -0700 Subject: [PATCH 04/73] Score-based generative enhancement model (#8567) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Score-based generative enhancement model in NeMo * Addressed comments, added unit test Signed-off-by: Ante Jukić --- examples/audio_tasks/audio_to_audio_eval.py | 18 +- examples/audio_tasks/conf/beamforming.yaml | 1 - examples/audio_tasks/conf/masking.yaml | 3 - examples/audio_tasks/conf/predictive.yaml | 130 ++ .../conf/score_based_generative.yaml | 149 ++ examples/audio_tasks/speech_enhancement.py | 45 +- nemo/collections/asr/data/audio_to_audio.py | 47 +- .../asr/data/audio_to_audio_dataset.py | 3 + nemo/collections/asr/losses/__init__.py | 2 +- nemo/collections/asr/losses/audio_losses.py | 226 ++- nemo/collections/asr/metrics/audio.py | 7 + nemo/collections/asr/models/__init__.py | 6 +- .../asr/models/audio_to_audio_model.py | 386 ++++- .../asr/models/enhancement_models.py | 695 +++++---- nemo/collections/asr/modules/audio_modules.py | 7 +- .../asr/modules/audio_preprocessing.py | 59 +- .../asr/parts/submodules/diffusion.py | 1310 +++++++++++++++++ requirements/requirements_asr.txt | 1 + tests/collections/asr/test_asr_datasets.py | 33 + tests/collections/asr/test_asr_losses.py | 192 ++- .../asr/test_audio_preprocessing.py | 14 +- 21 files changed, 2985 insertions(+), 349 deletions(-) create mode 100644 examples/audio_tasks/conf/predictive.yaml create mode 100644 examples/audio_tasks/conf/score_based_generative.yaml create mode 100644 nemo/collections/asr/parts/submodules/diffusion.py diff --git a/examples/audio_tasks/audio_to_audio_eval.py b/examples/audio_tasks/audio_to_audio_eval.py index 4ac68dfc84e7..ab6623df298d 100644 --- a/examples/audio_tasks/audio_to_audio_eval.py +++ b/examples/audio_tasks/audio_to_audio_eval.py @@ -61,6 +61,7 @@ import json import os import tempfile +from collections import defaultdict from dataclasses import dataclass, field, is_dataclass from typing import List, Optional @@ -101,6 +102,9 @@ class AudioEvaluationConfig(process_audio.ProcessConfig): # Metrics to calculate metrics: List[str] = field(default_factory=lambda: ['sdr', 'estoi']) + # Return metric values for each example + return_values_per_example: bool = False + def get_evaluation_dataloader(config): """Prepare a dataloader for evaluation. @@ -174,6 +178,9 @@ def main(cfg: AudioEvaluationConfig): # Setup metrics metrics = get_metrics(cfg) + if cfg.return_values_per_example and cfg.batch_size > 1: + raise ValueError('return_example_values is only supported for batch_size=1.') + # Processing if not cfg.only_score_manifest: # Process audio using the configured model and save in the output directory @@ -236,6 +243,10 @@ def main(cfg: AudioEvaluationConfig): num_files += 1 + if cfg.max_utts is not None and num_files >= cfg.max_utts: + logging.info('Reached max_utts: %s', cfg.max_utts) + break + # Prepare dataloader config = { 'manifest_filepath': temporary_manifest_filepath, @@ -249,6 +260,8 @@ def main(cfg: AudioEvaluationConfig): } temporary_dataloader = get_evaluation_dataloader(config) + metrics_value_per_example = defaultdict(list) + # Calculate metrics for eval_batch in tqdm(temporary_dataloader, desc='Evaluating'): processed_signal, processed_length, target_signal, target_length = eval_batch @@ -257,7 +270,9 @@ def main(cfg: AudioEvaluationConfig): raise RuntimeError(f'Length mismatch.') for name, metric in metrics.items(): - metric.update(preds=processed_signal, target=target_signal, input_length=target_length) + value = metric(preds=processed_signal, target=target_signal, input_length=target_length) + if cfg.return_values_per_example: + metrics_value_per_example[name].append(value.item()) # Convert to a dictionary with name: value metrics_value = {name: metric.compute().item() for name, metric in metrics.items()} @@ -277,6 +292,7 @@ def main(cfg: AudioEvaluationConfig): # Inject the metric name and score into the config, and return the entire config with open_dict(cfg): cfg.metrics_value = metrics_value + cfg.metrics_value_per_example = dict(metrics_value_per_example) return cfg diff --git a/examples/audio_tasks/conf/beamforming.yaml b/examples/audio_tasks/conf/beamforming.yaml index 18e04f0bd12a..3abc4f134e64 100644 --- a/examples/audio_tasks/conf/beamforming.yaml +++ b/examples/audio_tasks/conf/beamforming.yaml @@ -44,7 +44,6 @@ model: _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram fft_length: 512 # Length of the window and FFT for calculating spectrogram hop_length: 256 # Hop length for calculating spectrogram - power: null decoder: _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio diff --git a/examples/audio_tasks/conf/masking.yaml b/examples/audio_tasks/conf/masking.yaml index c667bec53076..68adca116aa5 100644 --- a/examples/audio_tasks/conf/masking.yaml +++ b/examples/audio_tasks/conf/masking.yaml @@ -1,5 +1,3 @@ -# This configuration contains the exemplary values for training a multichannel speech enhancement model with a mask-based beamformer. -# name: "masking" model: @@ -44,7 +42,6 @@ model: _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram fft_length: 512 # Length of the window and FFT for calculating spectrogram hop_length: 256 # Hop length for calculating spectrogram - power: null decoder: _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio diff --git a/examples/audio_tasks/conf/predictive.yaml b/examples/audio_tasks/conf/predictive.yaml new file mode 100644 index 000000000000..b141ba6fd1ee --- /dev/null +++ b/examples/audio_tasks/conf/predictive.yaml @@ -0,0 +1,130 @@ +name: "predictive_model" + +model: + type: predictive + sample_rate: 16000 + skip_nan_grad: false + num_outputs: 1 + normalize_input: true # normalize the input signal to 0dBFS + + train_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 256 + random_offset: true + normalization_signal: input_signal + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + batch_size: 8 + shuffle: false + num_workers: 4 + pin_memory: true + + encoder: + _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram + fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus + in_channels: 1 # single-channel noisy input + out_channels: 1 # single-channel estimate + num_res_blocks: 3 # increased number of res blocks + pad_time_to: 64 # pad to 64 frames for the time dimension + pad_dimension_to: 0 # no padding in the frequency dimension + + loss: + _target_: nemo.collections.asr.losses.MSELoss # computed in the time domain + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + + optim: + name: adam + lr: 1e-4 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: null + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. + enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: False # use original weights for validation calculation? + + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_sisdr + mode: max + save_top_k: 5 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/audio_tasks/conf/score_based_generative.yaml b/examples/audio_tasks/conf/score_based_generative.yaml new file mode 100644 index 000000000000..c0b36bd750a2 --- /dev/null +++ b/examples/audio_tasks/conf/score_based_generative.yaml @@ -0,0 +1,149 @@ +name: score_based_generative_model + +model: + type: score_based + sample_rate: 16000 + skip_nan_grad: false + num_outputs: 1 + normalize_input: true + max_utts_evaluation_metrics: 50 # metric calculation needs full inference and is slow, so we limit to first few files + + train_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 256 + random_offset: true + normalization_signal: input_signal + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + normalize_input: false # load data as is for validation, the model will normalize it for inference + batch_size: 4 + shuffle: false + num_workers: 4 + pin_memory: true + + encoder: + _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram + fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus + in_channels: 2 # concatenation of single-channel perturbed and noisy + out_channels: 1 # single-channel score estimate + conditioned_on_time: true + num_res_blocks: 3 # increased number of res blocks + pad_time_to: 64 # pad to 64 frames for the time dimension + pad_dimension_to: 0 # no padding in the frequency dimension + + sde: + _target_: nemo.collections.asr.parts.submodules.diffusion.OrnsteinUhlenbeckVarianceExplodingSDE + stiffness: 1.5 + std_min: 0.05 + std_max: 0.5 + num_steps: 1000 + + sampler: + _target_: nemo.collections.asr.parts.submodules.diffusion.PredictorCorrectorSampler + predictor: reverse_diffusion + corrector: annealed_langevin_dynamics + num_steps: 50 + num_corrector_steps: 1 + snr: 0.5 + + loss: + _target_: nemo.collections.asr.losses.MSELoss + ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + + optim: + name: adam + lr: 1e-4 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: null + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. + enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: false # use original weights for validation calculation? + + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_sisdr + mode: max + save_top_k: 5 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/audio_tasks/speech_enhancement.py b/examples/audio_tasks/speech_enhancement.py index 250d212d2a25..33a25c1c107c 100644 --- a/examples/audio_tasks/speech_enhancement.py +++ b/examples/audio_tasks/speech_enhancement.py @@ -26,25 +26,64 @@ PyTorch Lightning Trainer arguments and args of the model and the optimizer can be added or overriden from CLI """ +from enum import Enum + import pytorch_lightning as pl import torch from omegaconf import OmegaConf -from nemo.collections.asr.models import EncMaskDecAudioToAudioModel +from nemo.collections.asr.models.enhancement_models import ( + EncMaskDecAudioToAudioModel, + PredictiveAudioToAudioModel, + ScoreBasedGenerativeAudioToAudioModel, +) from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager +class ModelType(str, Enum): + """Enumeration with the available model types. + """ + + MaskBased = 'mask_based' + Predictive = 'predictive' + ScoreBased = 'score_based' + + +def get_model_class(model_type: ModelType): + """Get model class for a given model type. + """ + if model_type == ModelType.MaskBased: + return EncMaskDecAudioToAudioModel + elif model_type == ModelType.Predictive: + return PredictiveAudioToAudioModel + elif model_type == ModelType.ScoreBased: + return ScoreBasedGenerativeAudioToAudioModel + else: + raise ValueError(f'Unknown model type: {model_type}') + + @hydra_runner(config_path="./conf", config_name="masking") def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg, resolve=True)}') trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) - model = EncMaskDecAudioToAudioModel(cfg=cfg.model, trainer=trainer) - # Initialize the weights of the model from another model, if provided via config + # Get model class + model_type = cfg.model.get('type') + if model_type is None: + model_type = ModelType.MaskBased + logging.warning('model_type not found in config. Using default: %s', model_type) + + logging.info('Get class for model type: %s', model_type) + model_class = get_model_class(model_type) + + logging.info('Instantiate model %s', model_class.__name__) + model = model_class(cfg=cfg.model, trainer=trainer) + + logging.info('Initialize the weights of the model from another model, if provided via config') model.maybe_init_from_pretrained_checkpoint(cfg) # Train the model diff --git a/nemo/collections/asr/data/audio_to_audio.py b/nemo/collections/asr/data/audio_to_audio.py index a3c6dd0cc1b3..4f4727239a4b 100644 --- a/nemo/collections/asr/data/audio_to_audio.py +++ b/nemo/collections/asr/data/audio_to_audio.py @@ -130,13 +130,19 @@ class ASRAudioProcessor: sample_rate: sample rate used for all audio signals random_offset: If `True`, offset will be randomized when loading a subsegment from a file. + normalization_signal: Normalize all audio with a factor that ensures the signal + `example[normalization_signal]` in `process` is in range [-1, 1]. + All other audio signals are scaled by the same factor. Default is + `None`, corresponding to no normalization. """ def __init__( - self, sample_rate: float, random_offset: bool, + self, sample_rate: float, random_offset: bool, normalization_signal: Optional[str] = None, eps: float = 1e-8, ): self.sample_rate = sample_rate self.random_offset = random_offset + self.normalization_signal = normalization_signal + self.eps = eps self.sync_setup = None self.async_setup = None @@ -314,7 +320,20 @@ def process_audio(self, audio: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso Returns: An ordered dictionary of signals and their tensors. """ - # Currently, not doing any processing of the loaded signals. + if self.normalization_signal: + # Normalize all audio with a factor that ensures the normalization signal is in range [-1, 1]. + norm_scale = audio[self.normalization_signal].abs().max() + + # Do not normalize embeddings + skip_signals = self.embedding_setup.signals if self.embedding_setup is not None else [] + + # Normalize audio signals + for signal in audio: + if signal not in skip_signals: + # All audio signals are scaled by the same factor. + # This ensures that the relative level between signals is preserved. + audio[signal] = audio[signal] / (norm_scale + self.eps) + return audio def load_sync_signals(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]: @@ -812,6 +831,9 @@ class AudioToTargetDataset(BaseAudioDataset): If `None`, all channels will be loaded. target_channel_selector: Optional, select subset of channels from each input audio file. If `None`, all channels will be loaded. + normalization_signal: Normalize audio signals with a scale that ensures the normalization signal is in range [-1, 1]. + All audio signals are scaled by the same factor. Supported values are `None` (no normalization), + 'input_signal', 'target_signal'. """ def __init__( @@ -827,6 +849,7 @@ def __init__( max_utts: Optional[int] = None, input_channel_selector: Optional[int] = None, target_channel_selector: Optional[int] = None, + normalization_signal: Optional[str] = None, ): audio_to_manifest_key = { 'input_signal': input_key, @@ -841,7 +864,9 @@ def __init__( max_number=max_utts, ) - audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,) + audio_processor = ASRAudioProcessor( + sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal, + ) audio_processor.sync_setup = SignalSetup( signals=['input_signal', 'target_signal'], duration=audio_duration, @@ -932,6 +957,9 @@ class AudioToTargetWithReferenceDataset(BaseAudioDataset): from input and target. reference_duration: Optional, can be used to set a fixed duration of the reference utterance. If `None`, complete audio file will be loaded. + normalization_signal: Normalize audio signals with a scale that ensures the normalization signal is in range [-1, 1]. + All audio signals are scaled by the same factor. Supported values are `None` (no normalization), + 'input_signal', 'target_signal', 'reference_signal'. """ def __init__( @@ -951,6 +979,7 @@ def __init__( reference_channel_selector: Optional[int] = None, reference_is_synchronized: bool = True, reference_duration: Optional[float] = None, + normalization_signal: Optional[str] = None, ): audio_to_manifest_key = { 'input_signal': input_key, @@ -966,7 +995,9 @@ def __init__( max_number=max_utts, ) - audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,) + audio_processor = ASRAudioProcessor( + sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal, + ) if reference_is_synchronized: audio_processor.sync_setup = SignalSetup( @@ -1063,6 +1094,9 @@ class AudioToTargetWithEmbeddingDataset(BaseAudioDataset): If `None`, all channels will be loaded. target_channel_selector: Optional, select subset of channels from each input audio file. If `None`, all channels will be loaded. + normalization_signal: Normalize audio signals with a scale that ensures the normalization signal is in range [-1, 1]. + All audio signals are scaled by the same factor. Supported values are `None` (no normalization), + 'input_signal', 'target_signal'. """ def __init__( @@ -1079,6 +1113,7 @@ def __init__( max_utts: Optional[int] = None, input_channel_selector: Optional[int] = None, target_channel_selector: Optional[int] = None, + normalization_signal: Optional[str] = None, ): audio_to_manifest_key = { 'input_signal': input_key, @@ -1094,7 +1129,9 @@ def __init__( max_number=max_utts, ) - audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,) + audio_processor = ASRAudioProcessor( + sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal, + ) audio_processor.sync_setup = SignalSetup( signals=['input_signal', 'target_signal'], duration=audio_duration, diff --git a/nemo/collections/asr/data/audio_to_audio_dataset.py b/nemo/collections/asr/data/audio_to_audio_dataset.py index b296d64b1f2a..46e47020fda0 100644 --- a/nemo/collections/asr/data/audio_to_audio_dataset.py +++ b/nemo/collections/asr/data/audio_to_audio_dataset.py @@ -36,6 +36,7 @@ def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDat max_utts=config.get('max_utts', 0), input_channel_selector=config.get('input_channel_selector', None), target_channel_selector=config.get('target_channel_selector', None), + normalization_signal=config.get('normalization_signal', None), ) return dataset @@ -65,6 +66,7 @@ def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.A reference_channel_selector=config.get('reference_channel_selector', None), reference_is_synchronized=config.get('reference_is_synchronized', True), reference_duration=config.get('reference_duration', None), + normalization_signal=config.get('normalization_signal', None), ) return dataset @@ -91,5 +93,6 @@ def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.A max_utts=config.get('max_utts', 0), input_channel_selector=config.get('input_channel_selector', None), target_channel_selector=config.get('target_channel_selector', None), + normalization_signal=config.get('normalization_signal', None), ) return dataset diff --git a/nemo/collections/asr/losses/__init__.py b/nemo/collections/asr/losses/__init__.py index 3e50cea1d692..c03f7a48ffe3 100644 --- a/nemo/collections/asr/losses/__init__.py +++ b/nemo/collections/asr/losses/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss -from nemo.collections.asr.losses.audio_losses import SDRLoss +from nemo.collections.asr.losses.audio_losses import MSELoss, SDRLoss from nemo.collections.asr.losses.ctc import CTCLoss from nemo.collections.asr.losses.lattice_losses import LatticeLoss from nemo.collections.asr.losses.ssl_losses.contrastive import ContrastiveLoss diff --git a/nemo/collections/asr/losses/audio_losses.py b/nemo/collections/asr/losses/audio_losses.py index 62ce4a9f7edd..b0214375a713 100644 --- a/nemo/collections/asr/losses/audio_losses.py +++ b/nemo/collections/asr/losses/audio_losses.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -21,31 +21,33 @@ from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like from nemo.collections.asr.parts.utils.audio_utils import toeplitz from nemo.core.classes import Loss, Typing, typecheck -from nemo.core.neural_types import AudioSignal, LengthsType, LossType, MaskType, NeuralType +from nemo.core.neural_types import AudioSignal, LengthsType, LossType, MaskType, NeuralType, VoidType from nemo.utils import logging -__all__ = ['SDRLoss'] +__all__ = ['SDRLoss', 'MSELoss'] -def temporal_mean( +def calculate_mean( input: torch.Tensor, input_length: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, + dim: Union[int, Tuple[int]] = -1, keepdim: bool = False, eps: float = 1e-10, ) -> torch.Tensor: - """Calculate mean along temporal dimension with optionally + """Calculate mean along dimension `dim` with optionally averaging only over valid samples (based on the input length). Args: - input: Batch of signals, shape (B, C, T) + input: signal, for example (B, C, T) or (B, C, D, T) input_length: Optional, length of each example in the batch, shape (B,) - mask: Optional, temporal mask for each example in the batch, shape (B, T) + mask: Optional, temporal mask for each example in the batch, same shape as the input signal + dim: dimension or dimensions to reduce keepdim: Whether to keep the temporal dimension eps: Regularization to avoid division by zero Returns: - (B, C, 1) if keepdim=True, otherwise (B, C) + Mean over dimensions `dim`. """ if input_length is not None: if mask is not None: @@ -53,17 +55,18 @@ def temporal_mean( 'Argument `input_length` is mutually exclusive with `mask`. Both cannot be used at the same time.' ) # Construct a binary mask - mask = make_seq_mask_like(lengths=input_length, like=input, time_dim=-1, valid_ones=True).squeeze(1) + mask = make_seq_mask_like(lengths=input_length, like=input, time_dim=-1, valid_ones=True) + mask = mask.expand_as(input) if mask is None: # No length information, assume all samples are valid - mean = torch.mean(input, dim=-1, keepdim=keepdim) + mean = torch.mean(input, dim=dim, keepdim=keepdim) else: # Average using temporal mask - mean = mask.unsqueeze(1) * input - mean = torch.sum(mean, axis=-1, keepdim=keepdim) - normalization = torch.sum(mask, axis=-1, keepdim=keepdim) - mean = mean / (normalization.unsqueeze(1) + eps) + mean = mask * input + mean = torch.sum(mean, dim=dim, keepdim=keepdim) + normalization = torch.sum(mask, dim=dim, keepdim=keepdim) + mean = mean / (normalization + eps) return mean @@ -101,16 +104,17 @@ def scale_invariant_target( ) # Construct a binary mask - mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True).squeeze(1) + mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True) + mask = mask.expand_as(estimate) - estimate_dot_target = temporal_mean(estimate * target, mask=mask, keepdim=True, eps=eps) - target_pow = temporal_mean(torch.abs(target) ** 2, mask=mask, keepdim=True, eps=eps) + estimate_dot_target = calculate_mean(estimate * target, mask=mask, dim=-1, keepdim=True, eps=eps) + target_pow = calculate_mean(torch.abs(target) ** 2, mask=mask, dim=-1, keepdim=True, eps=eps) scale = estimate_dot_target / (target_pow + eps) target_scaled = scale * target # Mask to keep only the valid samples if mask is not None: - target_scaled = mask.unsqueeze(1) * target_scaled + target_scaled = mask * target_scaled return target_scaled @@ -162,12 +166,13 @@ def convolution_invariant_target( ) # Construct a binary mask - mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True).squeeze(1) + mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True) + mask = mask.expand_as(estimate) # Apply a mask, if available if mask is not None: - estimate = mask.unsqueeze(1) * estimate - target = mask.unsqueeze(1) * target + estimate = mask * estimate + target = mask * target # Calculate filtered target input_shape = estimate.shape @@ -207,7 +212,7 @@ def convolution_invariant_target( # Mask to keep only the valid samples if mask is not None: - target_filt = mask.unsqueeze(1) * target_filt + target_filt = mask * target_filt return target_filt @@ -261,11 +266,12 @@ def calculate_sdr_batch( ) # Construct a binary mask - mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True).squeeze(1) + mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True) + mask = mask.expand_as(estimate) if remove_mean: - estimate = estimate - temporal_mean(estimate, mask=mask, keepdim=True, eps=eps) - target = target - temporal_mean(target, mask=mask, keepdim=True, eps=eps) + estimate = estimate - calculate_mean(estimate, mask=mask, dim=-1, keepdim=True, eps=eps) + target = target - calculate_mean(target, mask=mask, dim=-1, keepdim=True, eps=eps) if scale_invariant or (convolution_invariant and convolution_filter_length == 1): target = scale_invariant_target(estimate=estimate, target=target, mask=mask, eps=eps) @@ -276,8 +282,8 @@ def calculate_sdr_batch( distortion = estimate - target - target_pow = temporal_mean(torch.abs(target) ** 2, mask=mask, eps=eps) - distortion_pow = temporal_mean(torch.abs(distortion) ** 2, mask=mask, eps=eps) + target_pow = calculate_mean(torch.abs(target) ** 2, mask=mask, dim=-1, eps=eps) + distortion_pow = calculate_mean(torch.abs(distortion) ** 2, mask=mask, dim=-1, eps=eps) if sdr_max is not None: distortion_pow = distortion_pow + 10 ** (-sdr_max / 10) * target_pow @@ -353,7 +359,7 @@ def input_types(self): "estimate": NeuralType(signal_shape, AudioSignal()), "target": NeuralType(signal_shape, AudioSignal()), "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), - "mask": NeuralType(('B', 'T'), MaskType(), optional=True), + "mask": NeuralType(('B', 'C', 'T'), MaskType(), optional=True), } @property @@ -376,10 +382,10 @@ def forward( perform averaging across channels (weighting optional), and apply reduction across the batch. Args: - estimate: Batch of signals, shape (B, T, C) - target: Batch of signals, shape (B, T, C) + estimate: Batch of signals, shape (B, C, T) + target: Batch of signals, shape (B, C, T) input_length: Batch of lengths, shape (B,) - mask: Batch of temporal masks, shape (B, T) + mask: Batch of temporal masks for each channel, shape (B, C, T) Returns: Scalar loss. @@ -410,3 +416,161 @@ def forward( sdr = self.reduce(sdr) return -sdr + + +def calculate_mse_batch( + estimate: torch.Tensor, + target: torch.Tensor, + input_length: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Calculate MSE per channel. + + MSE = ||estimate - target||_2^2 / input_length + + Args: + estimate: estimated signal, shape (B, C, T) or (B, C, D, T) + target: target signal, shape (B, C, T) or (B, C, D, T) + input_length: Optional, length of valid samples, shape (B,) + mask: Optional, temporal mask, same shape as signals + + Returns: + MSE for each channel, shape (B, C) + """ + assert ( + estimate.shape == target.shape + ), f'Estimate shape ({estimate.shape}) not matching target shape ({target.shape})' + + if input_length is not None: + if mask is not None: + raise RuntimeError( + 'Argument `input_length` is mutually exclusive with `mask`. Both cannot be used at the same time.' + ) + + # Construct a binary mask + mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True) + mask = mask.expand_as(estimate) + + # error + err = estimate - target + + # dimensions for averaging + if estimate.ndim == 3: + # average across time + dim = -1 + elif estimate.ndim == 4: + # average across time and features + dim = (-2, -1) + else: + raise RuntimeError(f'Unexpected dimension of the input: {estimate.shape}') + + # calculate masked mean + mse = calculate_mean(torch.abs(err) ** 2, mask=mask, dim=dim) + + return mse + + +class MSELoss(Loss, Typing): + """ + Computes MSE loss with weighted average across channels. + + Args: + weight: weight for loss of each output channel, used for averaging the loss across channels. Defaults to `None` (averaging). + reduction: batch reduction. Defaults to `mean` over the batch. + ndim: Number of dimensions for the input signal + """ + + def __init__( + self, weight: Optional[List[float]] = None, reduction: str = 'mean', ndim: int = 3, + ): + super().__init__() + + # weight buffer + if weight is not None: + if any([w <= 0 for w in weight]): + raise ValueError(f'Weight must be positive! Current value: {weight}') + elif not np.isclose(sum(weight), 1, atol=1e-6): + raise ValueError(f'Weight should add to one, current weight: {weight}') + weight = torch.tensor(weight).reshape(1, -1) + logging.info(f'Channel weight set to %s', weight) + self.register_buffer('weight', weight) + self.weight: Optional[Tensor] + + # Batch reduction + self.reduction = reduction + if reduction == 'mean': + self.reduce = torch.mean + else: + raise ValueError(f'Unexpected reduction mode {reduction}.') + + # Input dimension + self.ndim = ndim + + if self.ndim == 3: + # Time-domain input + self.signal_shape = ('B', 'C', 'T') + elif self.ndim == 4: + # Spectral-domain input + self.signal_shape = ('B', 'C', 'D', 'T') + else: + raise ValueError(f'Unexpected input dimension: {self.ndim}') + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tweight: %s', self.weight) + logging.debug('\treduction: %s', self.reduction) + logging.debug('\tndim: %s', self.ndim) + logging.debug('\tsignal_shape: %s', self.signal_shape) + + @property + def input_types(self): + """Input types definitions for SDRLoss. + """ + return { + "estimate": NeuralType(self.signal_shape, VoidType()), + "target": NeuralType(self.signal_shape, VoidType()), + "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "mask": NeuralType(self.signal_shape, MaskType(), optional=True), + } + + @property + def output_types(self): + """Output types definitions for SDRLoss. + loss: + NeuralType(None) + """ + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward( + self, + estimate: torch.Tensor, + target: torch.Tensor, + input_length: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """For input batch of multi-channel signals, calculate SDR between estimate and target for each channel, + perform averaging across channels (weighting optional), and apply reduction across the batch. + + Args: + estimate: Estimate of the target signal + target: Target signal + input_length: Length of each example in the batch + mask: Mask for each signal + + Returns: + Scalar loss. + """ + mse = calculate_mse_batch(estimate=estimate, target=target, input_length=input_length, mask=mask,) + + # channel averaging + if self.weight is None: + mse = torch.mean(mse, dim=1) + else: + # weighting across channels + mse = mse * self.weight + mse = torch.sum(mse, dim=1) + + # reduction + mse = self.reduce(mse) + + return mse diff --git a/nemo/collections/asr/metrics/audio.py b/nemo/collections/asr/metrics/audio.py index 5e8c2915e3fa..db63ac19c098 100644 --- a/nemo/collections/asr/metrics/audio.py +++ b/nemo/collections/asr/metrics/audio.py @@ -57,6 +57,7 @@ class AudioMetricWrapper(Metric): """ full_state_update: bool = False + num_examples: torch.Tensor def __init__( self, metric: Metric, channel: Optional[int] = None, metric_using_batch_averaging: Optional[bool] = None @@ -74,6 +75,7 @@ def __init__( self._metric = metric self._channel = channel + self.add_state('num_examples', default=torch.tensor(0), dist_reduce_fx='sum') logging.debug('Setup metric %s, channel %s', metric, str(channel)) def _select_channel(self, preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -144,6 +146,8 @@ def update(self, preds: torch.Tensor, target: torch.Tensor, input_length: Option for b_preds, b_target in self._trim_inputs(preds=preds, target=target, input_length=input_length): self._metric.update(preds=b_preds, target=b_target) + self.num_examples += preds.size(0) + def compute(self) -> torch.Tensor: """Compute the underlying metric. """ @@ -179,6 +183,9 @@ def forward( def reset(self) -> None: """Reset the underlying metric. """ + # reset the internal states + super().reset() + # reset the underlying metric self._metric.reset() def __repr__(self) -> str: diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 019c57f9c4e3..23c759afc80d 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -23,7 +23,11 @@ from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel -from nemo.collections.asr.models.enhancement_models import EncMaskDecAudioToAudioModel +from nemo.collections.asr.models.enhancement_models import ( + EncMaskDecAudioToAudioModel, + PredictiveAudioToAudioModel, + ScoreBasedGenerativeAudioToAudioModel, +) from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel from nemo.collections.asr.models.k2_sequence_models import ( diff --git a/nemo/collections/asr/models/audio_to_audio_model.py b/nemo/collections/asr/models/audio_to_audio_model.py index 49364843e8b8..094dbc38b72a 100644 --- a/nemo/collections/asr/models/audio_to_audio_model.py +++ b/nemo/collections/asr/models/audio_to_audio_model.py @@ -12,15 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +import os +import tempfile from abc import ABC, abstractmethod -from typing import List, Union +from typing import Dict, List, Optional, Union import hydra +import librosa +import soundfile as sf import torch from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer +from tqdm import tqdm +from nemo.collections.asr.data import audio_to_audio_dataset +from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset +from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config from nemo.collections.asr.metrics.audio import AudioMetricWrapper +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.core.classes import ModelPT from nemo.utils import logging, model_utils @@ -158,23 +169,384 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): return self.multi_evaluation_epoch_end(outputs, dataloader_idx, 'test') - @abstractmethod + @torch.no_grad() def process( - self, paths2audio_files: List[str], output_dir: str, batch_size: int = 4 - ) -> List[Union[str, List[str]]]: + self, + paths2audio_files: List[str], + output_dir: str, + batch_size: int = 1, + num_workers: Optional[int] = None, + input_channel_selector: Optional[ChannelSelectorType] = None, + ) -> List[str]: + """ + Process audio files provided in paths2audio_files. + Processed signals will be saved in output_dir. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + output_dir: + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + num_workers: Number of workers for the dataloader + input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. + + Returns: + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if num_workers is None: + num_workers = min(batch_size, os.cpu_count() - 1) + + # Output + paths2processed_files = [] + + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + + try: + # Switch model to evaluation mode + self.eval() + # Freeze weights + self.freeze() + + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + + # Processing + with tempfile.TemporaryDirectory() as tmpdir: + # Save temporary manifest + temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json') + with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp: + for audio_file in paths2audio_files: + entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)} + fp.write(json.dumps(entry) + '\n') + + config = { + 'manifest_filepath': temporary_manifest_filepath, + 'input_key': 'input_filepath', + 'input_channel_selector': input_channel_selector, + 'batch_size': min(batch_size, len(paths2audio_files)), + 'num_workers': num_workers, + } + + # Create output dir if necessary + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + + # DataLoader for the input files + temporary_dataloader = self._setup_process_dataloader(config) + + # Indexing of the original files, used to form the output file name + file_idx = 0 + + # Process batches + for test_batch in tqdm(temporary_dataloader, desc="Processing"): + input_signal = test_batch[0] + input_length = test_batch[1] + + # Expand channel dimension, if necessary + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = input_signal.unsqueeze(1) + + processed_batch, _ = self.forward( + input_signal=input_signal.to(device), input_length=input_length.to(device) + ) + + for example_idx in range(processed_batch.size(0)): + # This assumes the data loader is not shuffling files + file_name = os.path.basename(paths2audio_files[file_idx]) + # Prepare output file + output_file = os.path.join(output_dir, f'processed_{file_name}') + # Crop the output signal to the actual length + output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy() + # Write audio + sf.write(output_file, output_signal.T, self.sample_rate, 'float') + # Update the file counter + file_idx += 1 + # Save processed file + paths2processed_files.append(output_file) + + del test_batch + del processed_batch + + finally: + # set mode back to its original value + self.train(mode=mode) + if mode is True: + self.unfreeze() + logging.set_verbosity(logging_level) + + return paths2processed_files + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + + if config.get("use_lhotse", False): + return get_lhotse_dataloader_from_config( + config, global_rank=self.global_rank, world_size=self.world_size, dataset=LhotseAudioToTargetDataset() + ) + + is_concat = config.get('is_concat', False) + if is_concat: + raise NotImplementedError('Concat not implemented') + + # TODO: Consider moving `inject` from `audio_to_text_dataset` to a utility module? + # Automatically inject args from model config to dataloader config + inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate') + + # Instantiate tarred dataset loader or normal dataset loader + if config.get('is_tarred', False): + raise NotImplementedError('Tarred datasets not supported') + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config) + + if hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + elif hasattr(dataset.datasets[0], 'collate_fn'): + # support datasets that are lists of entries + collate_fn = dataset.datasets[0].collate_fn + else: + # support datasets that are lists of lists + collate_fn = dataset.datasets[0].datasets[0].collate_fn + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=config['shuffle'], + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the training data loader via a Dict-like object. + + Args: + train_data_config: A config that contains the information regarding construction + of a training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` + """ + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + raise NotImplementedError('Tarred datasets not supported') + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + + Args: + val_data_config: A config that contains the information regarding construction + of a validation dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + + Args: + test_data_config: A config that contains the information regarding construction + of a test dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """Prepare a dataloader for processing files. + + Args: + config: A python dictionary which contains the following keys: + manifest_filepath: path to a manifest file + input_key: key with audio filepaths in the manifest + input_channel_selector: Optional, used to select a subset of channels from input audio files + batch_size: batch size for the dataloader + num_workers: number of workers for the dataloader + + Returns: + A pytorch DataLoader for the given manifest filepath. + """ + dl_config = { + 'manifest_filepath': config['manifest_filepath'], + 'sample_rate': self.sample_rate, + 'input_key': config['input_key'], + 'input_channel_selector': config.get('input_channel_selector', None), + 'target_key': None, + 'target_channel_selector': None, + 'batch_size': config['batch_size'], + 'shuffle': False, + 'num_workers': config.get('num_workers', min(config['batch_size'], os.cpu_count() - 1)), + 'pin_memory': True, + } + + temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_dataloader + + @staticmethod + def match_batch_length(input: torch.Tensor, batch_length: int) -> torch.Tensor: + """Trim or pad the output to match the batch length. + + Args: + input: tensor with shape (B, C, T) + batch_length: int + + Returns: + Tensor with shape (B, C, T), where T matches the + batch length. + """ + input_length = input.size(-1) + pad_length = batch_length - input_length + pad = (0, pad_length) + # pad with zeros or crop + return torch.nn.functional.pad(input, pad, 'constant', 0) + + @torch.no_grad() + def process( + self, + paths2audio_files: List[str], + output_dir: str, + batch_size: int = 1, + num_workers: Optional[int] = None, + input_channel_selector: Optional[ChannelSelectorType] = None, + ) -> List[str]: """ Takes paths to audio files and returns a list of paths to processed audios. Args: paths2audio_files: paths to audio files to be processed - output_dir: directory to save processed files - batch_size: batch size for inference + output_dir: directory to save the processed files + batch_size: (int) batch size to use during inference. + num_workers: Number of workers for the dataloader + input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. + If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Returns: Paths to processed audio signals. """ - pass + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if num_workers is None: + num_workers = min(batch_size, os.cpu_count() - 1) + + # Output + paths2processed_files = [] + + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + + try: + # Switch model to evaluation mode + self.eval() + # Freeze weights + self.freeze() + + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + + # Processing + with tempfile.TemporaryDirectory() as tmpdir: + # Save temporary manifest + temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json') + with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp: + for audio_file in paths2audio_files: + entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)} + fp.write(json.dumps(entry) + '\n') + + config = { + 'manifest_filepath': temporary_manifest_filepath, + 'input_key': 'input_filepath', + 'input_channel_selector': input_channel_selector, + 'batch_size': min(batch_size, len(paths2audio_files)), + 'num_workers': num_workers, + } + + # Create output dir if necessary + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + + # DataLoader for the input files + temporary_dataloader = self._setup_process_dataloader(config) + + # Indexing of the original files, used to form the output file name + file_idx = 0 + + # Process batches + for test_batch in tqdm(temporary_dataloader, desc="Processing"): + input_signal = test_batch[0] + input_length = test_batch[1] + + # Expand channel dimension, if necessary + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = input_signal.unsqueeze(1) + + processed_batch, _ = self.forward( + input_signal=input_signal.to(device), input_length=input_length.to(device) + ) + + for example_idx in range(processed_batch.size(0)): + # This assumes the data loader is not shuffling files + file_name = os.path.basename(paths2audio_files[file_idx]) + # Prepare output file + output_file = os.path.join(output_dir, f'processed_{file_name}') + # Crop the output signal to the actual length + output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy() + # Write audio + sf.write(output_file, output_signal.T, self.sample_rate, 'float') + # Update the file counter + file_idx += 1 + # Save processed file + paths2processed_files.append(output_file) + + del test_batch + del processed_batch + + finally: + # set mode back to its original value + self.train(mode=mode) + if mode is True: + self.unfreeze() + logging.set_verbosity(logging_level) + + return paths2processed_files @classmethod def list_available_models(cls) -> 'List[PretrainedModelInfo]': diff --git a/nemo/collections/asr/models/enhancement_models.py b/nemo/collections/asr/models/enhancement_models.py index b80c357364aa..b765ae0fddad 100644 --- a/nemo/collections/asr/models/enhancement_models.py +++ b/nemo/collections/asr/models/enhancement_models.py @@ -16,6 +16,8 @@ import tempfile from typing import Dict, List, Optional, Union +import einops +import hydra import librosa import soundfile as sf import torch @@ -23,17 +25,13 @@ from pytorch_lightning import Trainer from tqdm import tqdm -from nemo.collections.asr.data import audio_to_audio_dataset -from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset -from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config + from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType -from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.core.classes.common import PretrainedModelInfo, typecheck -from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType +from nemo.core.neural_types import AudioSignal, LengthsType, LossType, NeuralType from nemo.utils import logging -__all__ = ['EncMaskDecAudioToAudioModel'] +__all__ = ['EncMaskDecAudioToAudioModel', 'ScoreBasedGenerativeAudioToAudioModel', 'PredictiveAudioToAudioModel'] class EncMaskDecAudioToAudioModel(AudioToAudioModel): @@ -69,10 +67,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): logging.debug('Mixture consistency not used') self.mixture_consistency = None - # Future enhancement: - # If subclasses need to modify the config before calling super() - # Check ASRBPE* classes do with their mixin - # Setup augmentation if hasattr(self.cfg, 'channel_augment') and self.cfg.channel_augment is not None: logging.debug('Using channel augmentation') @@ -84,254 +78,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup optional Optimization flags self.setup_optimization_flags() - @torch.no_grad() - def process( - self, - paths2audio_files: List[str], - output_dir: str, - batch_size: int = 1, - num_workers: Optional[int] = None, - input_channel_selector: Optional[ChannelSelectorType] = None, - ) -> List[str]: - """ - Process audio files provided in paths2audio_files. - Processed signals will be saved in output_dir. - - Args: - paths2audio_files: (a list) of paths to audio files. \ - Recommended length per file is between 5 and 25 seconds. \ - But it is possible to pass a few hours long file if enough GPU memory is available. - output_dir: - batch_size: (int) batch size to use during inference. - Bigger will result in better throughput performance but would use more memory. - num_workers: Number of workers for the dataloader - input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. - - Returns: - """ - if paths2audio_files is None or len(paths2audio_files) == 0: - return {} - - if num_workers is None: - num_workers = min(batch_size, os.cpu_count() - 1) - - # Output - paths2processed_files = [] - - # Model's mode and device - mode = self.training - device = next(self.parameters()).device - - try: - # Switch model to evaluation mode - self.eval() - # Freeze weights - self.freeze() - - logging_level = logging.get_verbosity() - logging.set_verbosity(logging.WARNING) - - # Processing - with tempfile.TemporaryDirectory() as tmpdir: - # Save temporary manifest - temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json') - with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp: - for audio_file in paths2audio_files: - entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)} - fp.write(json.dumps(entry) + '\n') - - config = { - 'manifest_filepath': temporary_manifest_filepath, - 'input_key': 'input_filepath', - 'input_channel_selector': input_channel_selector, - 'batch_size': min(batch_size, len(paths2audio_files)), - 'num_workers': num_workers, - } - - # Create output dir if necessary - if not os.path.isdir(output_dir): - os.makedirs(output_dir) - - # DataLoader for the input files - temporary_dataloader = self._setup_process_dataloader(config) - - # Indexing of the original files, used to form the output file name - file_idx = 0 - - # Process batches - for test_batch in tqdm(temporary_dataloader, desc="Processing"): - input_signal = test_batch[0] - input_length = test_batch[1] - - # Expand channel dimension, if necessary - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = input_signal.unsqueeze(1) - - processed_batch, _ = self.forward( - input_signal=input_signal.to(device), input_length=input_length.to(device) - ) - - for example_idx in range(processed_batch.size(0)): - # This assumes the data loader is not shuffling files - file_name = os.path.basename(paths2audio_files[file_idx]) - # Prepare output file - output_file = os.path.join(output_dir, f'processed_{file_name}') - # Crop the output signal to the actual length - output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy() - # Write audio - sf.write(output_file, output_signal.T, self.sample_rate, 'float') - # Update the file counter - file_idx += 1 - # Save processed file - paths2processed_files.append(output_file) - - del test_batch - del processed_batch - - finally: - # set mode back to its original value - self.train(mode=mode) - if mode is True: - self.unfreeze() - logging.set_verbosity(logging_level) - - return paths2processed_files - - def _setup_dataloader_from_config(self, config: Optional[Dict]): - - if config.get("use_lhotse", False): - return get_lhotse_dataloader_from_config( - config, global_rank=self.global_rank, world_size=self.world_size, dataset=LhotseAudioToTargetDataset() - ) - - is_concat = config.get('is_concat', False) - if is_concat: - raise NotImplementedError('Concat not implemented') - - # TODO: Consider moving `inject` from `audio_to_text_dataset` to a utility module? - # Automatically inject args from model config to dataloader config - inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate') - - # Instantiate tarred dataset loader or normal dataset loader - if config.get('is_tarred', False): - raise NotImplementedError('Tarred datasets not supported') - - if 'manifest_filepath' in config and config['manifest_filepath'] is None: - logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") - return None - - dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config) - - if hasattr(dataset, 'collate_fn'): - collate_fn = dataset.collate_fn - elif hasattr(dataset.datasets[0], 'collate_fn'): - # support datasets that are lists of entries - collate_fn = dataset.datasets[0].collate_fn - else: - # support datasets that are lists of lists - collate_fn = dataset.datasets[0].datasets[0].collate_fn - - return torch.utils.data.DataLoader( - dataset=dataset, - batch_size=config['batch_size'], - collate_fn=collate_fn, - drop_last=config.get('drop_last', False), - shuffle=config['shuffle'], - num_workers=config.get('num_workers', 0), - pin_memory=config.get('pin_memory', False), - ) - - def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): - """ - Sets up the training data loader via a Dict-like object. - - Args: - train_data_config: A config that contains the information regarding construction - of a training dataset. - - Supported Datasets: - - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` - """ - if 'shuffle' not in train_data_config: - train_data_config['shuffle'] = True - - # preserve config - self._update_dataset_config(dataset_name='train', config=train_data_config) - - self._train_dl = self._setup_dataloader_from_config(config=train_data_config) - - if 'is_tarred' in train_data_config and train_data_config['is_tarred']: - raise NotImplementedError('Tarred datasets not supported') - - def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): - """ - Sets up the validation data loader via a Dict-like object. - - Args: - val_data_config: A config that contains the information regarding construction - of a validation dataset. - - Supported Datasets: - - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` - """ - if 'shuffle' not in val_data_config: - val_data_config['shuffle'] = False - - # preserve config - self._update_dataset_config(dataset_name='validation', config=val_data_config) - - self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) - - def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): - """ - Sets up the test data loader via a Dict-like object. - - Args: - test_data_config: A config that contains the information regarding construction - of a test dataset. - - Supported Datasets: - - :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset` - """ - if 'shuffle' not in test_data_config: - test_data_config['shuffle'] = False - - # preserve config - self._update_dataset_config(dataset_name='test', config=test_data_config) - - self._test_dl = self._setup_dataloader_from_config(config=test_data_config) - - def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': - """Prepare a dataloader for processing files. - - Args: - config: A python dictionary which contains the following keys: - manifest_filepath: path to a manifest file - input_key: key with audio filepaths in the manifest - input_channel_selector: Optional, used to select a subset of channels from input audio files - batch_size: batch size for the dataloader - num_workers: number of workers for the dataloader - - Returns: - A pytorch DataLoader for the given manifest filepath. - """ - dl_config = { - 'manifest_filepath': config['manifest_filepath'], - 'sample_rate': self.sample_rate, - 'input_key': config['input_key'], - 'input_channel_selector': config.get('input_channel_selector', None), - 'target_key': None, - 'target_channel_selector': None, - 'batch_size': config['batch_size'], - 'shuffle': False, - 'num_workers': config.get('num_workers', min(config['batch_size'], os.cpu_count() - 1)), - 'pin_memory': True, - } - - temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config)) - return temporary_dataloader - @property def input_types(self) -> Dict[str, NeuralType]: return { @@ -350,23 +96,6 @@ def output_types(self) -> Dict[str, NeuralType]: "output_length": NeuralType(tuple('B'), LengthsType(), optional=True), } - def match_batch_length(self, input: torch.Tensor, batch_length: int): - """Trim or pad the output to match the batch length. - - Args: - input: tensor with shape (B, C, T) - batch_length: int - - Returns: - Tensor with shape (B, C, T), where T matches the - batch length. - """ - input_length = input.size(-1) - pad_length = batch_length - input_length - pad = (0, pad_length) - # pad with zeros or crop - return torch.nn.functional.pad(input, pad, 'constant', 0) - @typecheck() def forward(self, input_signal, input_length=None): """ @@ -380,6 +109,7 @@ def forward(self, input_signal, input_length=None): sequences. Returns: + Output signal `output` in the time domain and the length of the output signal `output_length`. """ batch_length = input_signal.size(-1) @@ -414,12 +144,11 @@ def training_step(self, batch, batch_idx): else: input_signal, input_length, target_signal, _ = batch - # Expand channel dimension, if necessary # For consistency, the model uses multi-channel format, even if the channel dimension is 1 if input_signal.ndim == 2: - input_signal = input_signal.unsqueeze(1) + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') if target_signal.ndim == 2: - target_signal = target_signal.unsqueeze(1) + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') # Apply channel augmentation if self.training and self.channel_augmentation is not None: @@ -449,12 +178,11 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = else: input_signal, input_length, target_signal, _ = batch - # Expand channel dimension, if necessary # For consistency, the model uses multi-channel format, even if the channel dimension is 1 if input_signal.ndim == 2: - input_signal = input_signal.unsqueeze(1) + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') if target_signal.ndim == 2: - target_signal = target_signal.unsqueeze(1) + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') # Process input processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) @@ -485,3 +213,406 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]: results = [] return results + + +class PredictiveAudioToAudioModel(AudioToAudioModel): + """This models aims to directly estimate the coefficients + in the encoded domain by applying a neural model. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self.sample_rate = self._cfg.sample_rate + + # Setup processing modules + self.encoder = self.from_config_dict(self._cfg.encoder) + self.decoder = self.from_config_dict(self._cfg.decoder) + + # Neural estimator + self.estimator = self.from_config_dict(self._cfg.estimator) + + # Normalization + self.normalize_input = self._cfg.get('normalize_input', False) + + # Term added to the denominator to improve numerical stability + self.eps = self._cfg.get('eps', 1e-8) + + # Setup optional Optimization flags + self.setup_optimization_flags() + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tnormalize_input: %s', self.normalize_input) + logging.debug('\teps: %s', self.eps) + + @property + def input_types(self) -> Dict[str, NeuralType]: + return { + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return { + "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "output_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @typecheck() + def forward(self, input_signal, input_length=None): + """Forward pass of the model. + + Args: + input_signal: time-domain signal + input_length: valid length of each example in the batch + + Returns: + Output signal `output` in the time domain and the length of the output signal `output_length`. + """ + batch_length = input_signal.size(-1) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + + # Encoder + encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) + + # Backbone + estimated, estimated_length = self.estimator(input=encoded, input_length=encoded_length) + + # Decoder + output, output_length = self.decoder(input=estimated, input_length=estimated_length) + + if self.normalize_input: + # rescale to the original scale + output = output * norm_scale + + # Trim or pad the estimated signal to match input length + output = self.match_batch_length(input=output, batch_length=batch_length) + return output, output_length + + # PTL-specific methods + def training_step(self, batch, batch_idx): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Estimate the signal + output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) + + # Calculate the loss + loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length) + + # Logs + self.log('train_loss', loss) + self.log('learning_rate', self._optimizer.param_groups[0]['lr']) + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return loss + + def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Estimate the signal + output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) + + # Prepare output + loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length) + + # Update metrics + if hasattr(self, 'metrics') and tag in self.metrics: + # Update metrics for this (tag, dataloader_idx) + for name, metric in self.metrics[tag][dataloader_idx].items(): + metric.update(preds=output_signal, target=target_signal, input_length=input_length) + + # Log global step + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return {f'{tag}_loss': loss} + + +class ScoreBasedGenerativeAudioToAudioModel(AudioToAudioModel): + """This models is using a score-based diffusion process to generate + an encoded representation of the enhanced signal. + + The model consists of the following blocks: + - encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform) + - estimator: neural model, estimates a score for the diffusion process + - sde: stochastic differential equation (SDE) defining the forward and reverse diffusion process + - sampler: sampler for the reverse diffusion process, estimates coefficients of the target signal + - decoder: transforms sampler output into the time domain (synthesis transform) + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self.sample_rate = self._cfg.sample_rate + + # Setup processing modules + self.encoder = self.from_config_dict(self._cfg.encoder) + self.decoder = self.from_config_dict(self._cfg.decoder) + + # Neural score estimator + self.estimator = self.from_config_dict(self._cfg.estimator) + + # SDE + self.sde = self.from_config_dict(self._cfg.sde) + + # Sampler + if 'sde' in self._cfg.sampler: + raise ValueError('SDE should be defined in the model config, not in the sampler config') + if 'score_estimator' in self._cfg.sampler: + raise ValueError('Score estimator should be defined in the model config, not in the sampler config') + + self.sampler = hydra.utils.instantiate(self._cfg.sampler, sde=self.sde, score_estimator=self.estimator) + + # Normalization + self.normalize_input = self._cfg.get('normalize_input', False) + + # Metric evaluation + self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics') + + if self.max_utts_evaluation_metrics is not None: + logging.warning( + 'Metrics will be evaluated on first %d examples of the evaluation datasets.', + self.max_utts_evaluation_metrics, + ) + + # Term added to the denominator to improve numerical stability + self.eps = self._cfg.get('eps', 1e-8) + + # Setup optional Optimization flags + self.setup_optimization_flags() + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tnormalize_input: %s', self.normalize_input) + logging.debug('\teps: %s', self.eps) + + @property + def input_types(self) -> Dict[str, NeuralType]: + return { + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return { + "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "output_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @typecheck() + @torch.inference_mode() + def forward(self, input_signal, input_length=None): + """Forward pass of the model. + + Forward pass of the model aplies the following steps: + - encoder to obtain the encoded representation of the input signal + - sampler to generate the estimated coefficients of the target signal + - decoder to transform the sampler output into the time domain + + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T] or [B, T, C]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + + Returns: + Output signal `output` in the time domain and the length of the output signal `output_length`. + """ + batch_length = input_signal.size(-1) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + + # Encoder + encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) + + # Sampler + generated, generated_length = self.sampler( + prior_mean=encoded, score_condition=encoded, state_length=encoded_length + ) + + # Decoder + output, output_length = self.decoder(input=generated, input_length=generated_length) + + if self.normalize_input: + # rescale to the original scale + output = output * norm_scale + + # Trim or pad the estimated signal to match input length + output = self.match_batch_length(input=output, batch_length=batch_length) + return output, output_length + + @typecheck( + input_types={ + "target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_length": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"loss": NeuralType(None, LossType()),}, + ) + def _step(self, target_signal, input_signal, input_length=None): + """Randomly generate a time step for each example in the batch, estimate + the score and calculate the loss value. + + Note that this step does not include sampler. + """ + batch_size = target_signal.size(0) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + # scale the target signal + target_signal = target_signal / (norm_scale + self.eps) + + # Apply encoder to both target and the input + input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length) + target_enc, _ = self.encoder(input=target_signal, input_length=input_length) + + # Generate random time steps + sde_time = self.sde.generate_time(size=batch_size, device=input_enc.device) + + # Get the mean and the variance of the perturbation kernel + pk_mean, pk_std = self.sde.perturb_kernel_params(state=target_enc, prior_mean=input_enc, time=sde_time) + + # Generate a random sample from a standard normal distribution + z_norm = torch.randn_like(input_enc) + + # Prepare perturbed data + perturbed_enc = pk_mean + pk_std * z_norm + + # Score is conditioned on the perturbed data and the input + estimator_input = torch.cat([perturbed_enc, input_enc], dim=-3) + + # Estimate the score using the neural estimator + # SDE time is used to inform the estimator about the current time step + # Note: + # - some implementations use `score = -self._raw_dnn_output(x, t, y)` + # - this seems to be unimportant, and is an artifact of transfering code from the original Song's repo + score_est, score_len = self.estimator(input=estimator_input, input_length=input_enc_len, condition=sde_time) + + # Score loss weighting as in Section 4.2 in http://arxiv.org/abs/1907.05600 + score_est = score_est * pk_std + score_ref = -z_norm + + # Score matching loss on the normalized scores + loss = self.loss(estimate=score_est, target=score_ref, input_length=score_len) + + return loss + + # PTL-specific methods + def training_step(self, batch, batch_idx): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Calculate the loss + loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + + # Logs + self.log('train_loss', loss) + self.log('learning_rate', self._optimizer.param_groups[0]['lr']) + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return loss + + def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Calculate loss + loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + + # Update metrics + update_metrics = False + if self.max_utts_evaluation_metrics is None: + # Always update if max is not configured + update_metrics = True + # Number of examples to process + num_examples = input_signal.size(0) # batch size + else: + # Check how many examples have been used for metric calculation + first_metric_name = next(iter(self.metrics[tag][dataloader_idx])) + num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples + # Update metrics if some examples were not processed + update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics + # Number of examples to process + num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0)) + + if update_metrics: + # Generate output signal + output_signal, _ = self.forward( + input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples] + ) + + # Update metrics + if hasattr(self, 'metrics') and tag in self.metrics: + # Update metrics for this (tag, dataloader_idx) + for name, metric in self.metrics[tag][dataloader_idx].items(): + metric.update( + preds=output_signal, + target=target_signal[:num_examples, ...], + input_length=input_length[:num_examples], + ) + + # Log global step + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return {f'{tag}_loss': loss} diff --git a/nemo/collections/asr/modules/audio_modules.py b/nemo/collections/asr/modules/audio_modules.py index 82cfbefeb8d9..67a923099cde 100644 --- a/nemo/collections/asr/modules/audio_modules.py +++ b/nemo/collections/asr/modules/audio_modules.py @@ -17,7 +17,7 @@ import numpy as np import torch -from nemo.collections.asr.losses.audio_losses import temporal_mean +from nemo.collections.asr.losses.audio_losses import calculate_mean from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like from nemo.collections.asr.parts.submodules.multichannel_modules import ( @@ -39,6 +39,7 @@ 'MaskReferenceChannel', 'MaskBasedBeamformer', 'MaskBasedDereverbWPE', + 'MixtureConsistencyProjection', ] @@ -158,7 +159,7 @@ def get_mean_time_channel(input: torch.Tensor, input_length: Optional[torch.Tens mean = torch.mean(input, dim=(-1, -3), keepdim=True) else: # temporal mean - mean = temporal_mean(input, input_length, keepdim=True) + mean = calculate_mean(input, input_length, dim=-1, keepdim=True) # channel mean mean = torch.mean(mean, dim=-3, keepdim=True) @@ -186,7 +187,7 @@ def get_mean_std_time_channel( mean = cls.get_mean_time_channel(input, input_length) std = (input - mean).pow(2) # temporal mean - std = temporal_mean(std, input_length, keepdim=True) + std = calculate_mean(std, input_length, dim=-1, keepdim=True) # channel mean std = torch.mean(std, dim=-3, keepdim=True) # final value diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py index cc5312403255..643bc4a69d69 100644 --- a/nemo/collections/asr/modules/audio_preprocessing.py +++ b/nemo/collections/asr/modules/audio_preprocessing.py @@ -709,9 +709,11 @@ class AudioToSpectrogram(NeuralModule): hop_length: length of hops/shifts of the sliding window power: exponent for magnitude spectrogram. Default `None` will return a complex-valued spectrogram + magnitude_power: Transform magnitude of the spectrogram as x^magnitude_power. + scale: Positive scaling of the spectrogram. """ - def __init__(self, fft_length: int, hop_length: int, power: Optional[float] = None): + def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): if not HAVE_TORCHAUDIO: logging.error('Could not import torchaudio. Some features might not work.') @@ -726,12 +728,26 @@ def __init__(self, fft_length: int, hop_length: int, power: Optional[float] = No raise ValueError(f'fft_length = {fft_length} must be divisible by 2') self.stft = torchaudio.transforms.Spectrogram( - n_fft=fft_length, hop_length=hop_length, power=power, pad_mode='constant' + n_fft=fft_length, hop_length=hop_length, power=None, pad_mode='constant' ) # number of subbands self.F = fft_length // 2 + 1 + if magnitude_power <= 0: + raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') + self.magnitude_power = magnitude_power + + if scale <= 0: + raise ValueError(f'Scale needs to be positive: current value {scale}') + self.scale = scale + + logging.debug('Initialized %s with:', self.__class__.__name__) + logging.debug('\tfft_length: %s', fft_length) + logging.debug('\thop_length: %s', hop_length) + logging.debug('\tmagnitude_power: %s', magnitude_power) + logging.debug('\tscale: %s', scale) + @property def num_subbands(self) -> int: return self.F @@ -776,6 +792,14 @@ def forward( with torch.cuda.amp.autocast(enabled=False): output = self.stft(input.float()) + if self.magnitude_power != 1: + # apply power on the magnitude + output = torch.pow(output.abs(), self.magnitude_power) * torch.exp(1j * output.angle()) + + if self.scale != 1: + # apply scaling of the coefficients + output = self.scale * output + if input_length is not None: # Mask padded frames output_length = self.get_output_length(input_length=input_length) @@ -810,11 +834,11 @@ class SpectrogramToAudio(NeuralModule): Args: fft_length: length of FFT hop_length: length of hops/shifts of the sliding window - power: exponent for magnitude spectrogram. Default `None` will - return a complex-valued spectrogram + magnitude_power: Transform magnitude of the spectrogram as x^(1/magnitude_power). + scale: Spectrogram will be scaled with 1/scale before the inverse transform. """ - def __init__(self, fft_length: int, hop_length: int): + def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): if not HAVE_TORCHAUDIO: logging.error('Could not import torchaudio. Some features might not work.') @@ -834,6 +858,20 @@ def __init__(self, fft_length: int, hop_length: int): self.F = fft_length // 2 + 1 + if magnitude_power <= 0: + raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') + self.magnitude_power = magnitude_power + + if scale <= 0: + raise ValueError(f'Scale needs to be positive: current value {scale}') + self.scale = scale + + logging.debug('Initialized %s with:', self.__class__.__name__) + logging.debug('\tfft_length: %s', fft_length) + logging.debug('\thop_length: %s', hop_length) + logging.debug('\tmagnitude_power: %s', magnitude_power) + logging.debug('\tscale: %s', scale) + @property def num_subbands(self) -> int: return self.F @@ -875,7 +913,16 @@ def forward(self, input: torch.Tensor, input_length: Optional[torch.Tensor] = No # iSTFT output (B, C, T) with torch.cuda.amp.autocast(enabled=False): - output = self.istft(input.cfloat()) + output = input.cfloat() + + if self.scale != 1: + # apply 1/scale on the coefficients + output = output / self.scale + + if self.magnitude_power != 1: + # apply 1/power on the magnitude + output = torch.pow(output.abs(), 1 / self.magnitude_power) * torch.exp(1j * output.angle()) + output = self.istft(output) if input_length is not None: # Mask padded samples diff --git a/nemo/collections/asr/parts/submodules/diffusion.py b/nemo/collections/asr/parts/submodules/diffusion.py new file mode 100644 index 000000000000..db3d30f49701 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/diffusion.py @@ -0,0 +1,1310 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from abc import ABC, abstractmethod +from typing import Dict, Optional, Sequence, Tuple, Type + +import einops +import einops.layers.torch +import numpy as np +import torch +import torch.nn.functional as F + +from nemo.collections.common.parts.utils import activation_registry +from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType, VoidType +from nemo.utils import logging + +__all__ = [ + 'OrnsteinUhlenbeckVarianceExplodingSDE', + 'SpectrogramNoiseConditionalScoreNetworkPlusPlus', + 'NoiseConditionalScoreNetworkPlusPlus', + 'PredictorCorrectorSampler', +] + + +class StochasticDifferentialEquation(NeuralModule, ABC): + """Base class for stochastic differential equations. + """ + + def __init__(self, time_min: float, time_max: float, num_steps: int): + super().__init__() + + # min and max time + if time_min <= 0: + raise ValueError(f'time_min should be positive, current value {time_min}') + + if time_max <= time_min: + raise ValueError(f'time_max should be larger than time_min, current max {time_max} and min {time_min}') + + self.time_min = time_min + self.time_max = time_max + + # number of steps + if num_steps <= 0: + raise ValueError(f'num_steps needs to be positive: current value {num_steps}') + + self.num_steps = num_steps + + @property + def dt(self) -> float: + """Time step for this SDE. + This denotes the step size between `0` and `self.time_max` when using `self.num_steps`. + """ + return self.time_max / self.num_steps + + @property + def time_delta(self) -> float: + """Time range for this SDE. + """ + return self.time_max - self.time_min + + def generate_time(self, size: int, device: torch.device) -> torch.Tensor: + """Generate random time steps in the valid range. + + Time steps are generated between `self.time_min` and `self.time_max`. + + Args: + size: number of samples + device: device to use + + Returns: + A tensor of floats with shape (size,) + """ + time = torch.rand(size, device=device) * self.time_delta + self.time_min + return time + + @abstractmethod + def coefficients(self, state: torch.Tensor, time: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + state: tensor of shape (B, C, D, T) + time: tensor of shape (B,) + + Returns: + Tuple with drift and diffusion coefficients. + """ + pass + + @typecheck( + input_types={"prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()),}, + output_types={"sample": NeuralType(('B', 'C', 'D', 'T'), VoidType()),}, + ) + @abstractmethod + def prior_sampling(self, prior_mean: torch.Tensor) -> torch.Tensor: + """Generate a sample from the prior distribution p_T. + + Args: + prior_mean: Mean of the prior distribution + + Returns: + A sample from the prior distribution. + """ + pass + + def discretize( + self, *, state: torch.Tensor, time: torch.Tensor, state_length: Optional[torch.Tensor] = None, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Assume we have the following SDE: + + dx = drift(x, t) * dt + diffusion(x, t) * dwt + + where `wt` is the standard Wiener process. + + We assume the following discretization: + + new_state = current_state + total_drift + total_diffusion * z_norm + + where `z_norm` is sampled from normal distribution with zero mean and unit variance. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + state_length: length of the valid time steps for each example in the batch, shape (B,) + **kwargs: other parameters + + Returns: + Drift and diffusion. + """ + # Get coefficients + drift_coefficient, diffusion_coefficient = self.coefficients( + state=state, time=time, state_length=state_length, **kwargs + ) + + # Discretized drift + drift = drift_coefficient * self.dt + + # Note: + # Scale with sqrt(dt) because z_norm is sampled from a normal distribution with zero mean and + # unit variance and dwt is normally distributed with zero mean and variance dt + diffusion = diffusion_coefficient * np.sqrt(self.dt) + + return drift, diffusion + + @abstractmethod + def copy(self): + """Create a copy of this SDE. + """ + pass + + def __repr__(self): + desc = f'{self.__class__.__name__}(time_min={self.time_min}, time_max={self.time_max}, num_steps={self.num_steps})' + desc += f'\n\tdt: {self.dt}' + desc += f'\n\ttime_delta: {self.time_delta}' + return desc + + +class OrnsteinUhlenbeckVarianceExplodingSDE(StochasticDifferentialEquation): + """This class implements the Ornstein-Uhlenbeck SDE with variance exploding noise schedule. + + The SDE is given by: + + dx = theta * (y - x) dt + g(t) dw + + where `theta` is the stiffness parameter and `g(t)` is the diffusion coefficient: + + g(t) = std_min * (std_max/std_min)^t * sqrt(2 * log(std_max/std_min)) + + References: + Richter et al., Speech Enhancement and Dereverberation with Diffusion-based Generative Models, Tr. ASLP 2023 + """ + + def __init__( + self, + stiffness: float, + std_min: float, + std_max: float, + num_steps: int = 100, + time_min: float = 3e-2, + time_max: float = 1.0, + eps: float = 1e-8, + ): + super().__init__(time_min=time_min, time_max=time_max, num_steps=num_steps) + + # Small regularization + if eps <= 0: + raise ValueError(f'eps should be positive, current value {eps}') + self.eps = eps + + # stifness + self.stiffness = stiffness + + # noise schedule + if std_min <= 0: + raise ValueError(f'std_min should be positive, current value {std_min}') + + if std_max <= std_min: + raise ValueError(f'std_max should be larger than std_min, current max {std_max} and min {std_min}') + + self.std_min = std_min + self.std_max = std_max + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tstiffness: %s', self.stiffness) + logging.debug('\tstd_min: %s', self.std_min) + logging.debug('\tstd_max: %s', self.std_max) + logging.debug('\tnum_steps: %s', self.num_steps) + logging.debug('\ttime_min: %s', self.time_min) + logging.debug('\ttime_max: %s', self.time_max) + logging.debug('\teps: %s', self.eps) + + @property + def std_ratio(self) -> float: + return self.std_max / (self.std_min + self.eps) + + @property + def log_std_ratio(self) -> float: + return np.log(self.std_ratio + self.eps) + + @typecheck( + input_types={ + "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "time": NeuralType(tuple('B'), FloatType()), + }, + output_types={"mean": NeuralType(('B', 'C', 'D', 'T'), FloatType()),}, + ) + def perturb_kernel_mean(self, state: torch.Tensor, prior_mean: torch.Tensor, time: torch.Tensor) -> torch.Tensor: + """Return the mean of the perturbation kernel for this SDE. + + Args: + state: current state of the process, shape (B, C, D, T) + prior_mean: mean of the prior distribution + time: current time of the process, shape (B,) + + Returns: + A tensor of shape (B, C, D, T) + """ + # exponential weighting + weight = torch.exp(-self.stiffness * time) + + # view as [B, C, D, T] + weight = weight.view(-1, 1, 1, 1) + + # closed-form mean + mean = weight * state + (1 - weight) * prior_mean + + return mean + + @typecheck( + input_types={"time": NeuralType(tuple('B'), FloatType()),}, + output_types={"std": NeuralType(tuple('B'), FloatType()),}, + ) + def perturb_kernel_std(self, time: torch.Tensor) -> torch.Tensor: + """Return the standard deviation of the perturbation kernel for this SDE. + + Note that the standard deviation depends on the time and the noise schedule, + which is parametrized using `self.stiffness`, `self.std_min` and `self.std_max`. + + Args: + time: current time of the process, shape (B,) + + Returns: + A tensor of shape (B,) + """ + var = (self.std_min ** 2) * self.log_std_ratio + var *= torch.pow(self.std_ratio, 2 * time) - torch.exp(-2 * self.stiffness * time) + var /= self.stiffness + self.log_std_ratio + std = torch.sqrt(var) + return std + + @typecheck( + input_types={ + "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "time": NeuralType(tuple('B'), FloatType()), + }, + output_types={ + "mean": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + "std": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + }, + ) + def perturb_kernel_params(self, state: torch.Tensor, prior_mean: torch.Tensor, time: torch.Tensor) -> torch.Tensor: + """Return the mean and standard deviation of the perturbation kernel for this SDE. + + Args: + state: current state of the process, shape (B, C, D, T) + prior_mean: mean of the prior distribution + time: current time of the process, shape (B,) + """ + assert torch.all(time <= self.time_max) + assert torch.all(time >= self.time_min) + + # compute the mean + mean = self.perturb_kernel_mean(state=state, prior_mean=prior_mean, time=time) + + # compute the standard deviation + std = self.perturb_kernel_std(time=time) + # view as [B, C, D, T] + std = std.view(-1, 1, 1, 1) + + return mean, std + + @typecheck( + input_types={ + "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "time": NeuralType(tuple('B'), VoidType()), + "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + output_types={ + "drift_coefficient": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + "diffusion_coefficient": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + }, + ) + def coefficients( + self, + state: torch.Tensor, + time: torch.Tensor, + prior_mean: torch.Tensor, + state_length: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute drift and diffusion coefficients for this SDE. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + prior_mean: mean of the prior distribution + state_length: length of the valid time steps for each example in the batch + + Returns: + Drift and diffusion coefficients. + """ + # Drift coefficient + drift_coefficient = self.stiffness * (prior_mean - state) + + # Diffusion coefficient + diffusion_coefficient = self.std_min * torch.pow(self.std_ratio, time) * np.sqrt(2 * self.log_std_ratio) + # View in the same shape as the state + diffusion_coefficient = diffusion_coefficient.view(-1, *([1] * (state.dim() - 1))) + + if state_length is not None: + drift_coefficient = mask_sequence_tensor(drift_coefficient, state_length) + diffusion_coefficient = mask_sequence_tensor(diffusion_coefficient, state_length) + + return drift_coefficient, diffusion_coefficient + + def prior_sampling(self, prior_mean: torch.Tensor) -> torch.Tensor: + """Generate a sample from the prior distribution p_T. + + Args: + prior_mean: Mean of the prior distribution + """ + # Final time step for all samples in the batch + time = self.time_max * torch.ones(prior_mean.shape[0], device=prior_mean.device) + + # Compute the std of the prior distribution + std = self.perturb_kernel_std(time=time) + + # view as [B, C, D, T] + std = std.view(-1, 1, 1, 1) + + # Generate a sample from a normal distribution centered at prior_mean + sample = prior_mean + torch.randn_like(prior_mean) * std + + return sample + + def copy(self): + return OrnsteinUhlenbeckVarianceExplodingSDE( + stiffness=self.stiffness, + std_min=self.std_min, + std_max=self.std_max, + num_steps=self.num_steps, + time_min=self.time_min, + time_max=self.time_max, + eps=self.eps, + ) + + def __repr__(self): + desc = f'{self.__class__.__name__}(stiffness={self.stiffness}, std_min={self.std_min}, std_max={self.std_max}, num_steps={self.num_steps}, time_min={self.time_min}, time_max={self.time_max}, eps={self.eps})' + desc += f'\n\tdt: {self.dt}' + desc += f'\n\ttime_delta: {self.time_delta}' + desc += f'\n\tstd_ratio: {self.std_ratio}' + desc += f'\n\tlog_std_ratio: {self.log_std_ratio}' + + return desc + + +class ReverseStochasticDifferentialEquation(StochasticDifferentialEquation): + def __init__(self, *, sde: Type[StochasticDifferentialEquation], score_estimator: Type[NeuralModule]): + """Use the forward SDE and a score estimator to define the reverse SDE. + + Args: + sde: forward SDE + score_estimator: neural score estimator + """ + super().__init__(time_min=sde.time_min, time_max=sde.time_max, num_steps=sde.num_steps) + self.score_estimator = score_estimator + self.forward_sde = sde + + logging.debug('Initialized %s', self.__class__.__name__) + + def coefficients( + self, + state: torch.Tensor, + time: torch.Tensor, + score_condition: Optional[torch.Tensor] = None, + state_length: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute drift and diffusion coefficients for the reverse SDE. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + """ + raise NotImplementedError('Coefficients not necessary for the reverse SDE.') + + def prior_sampling(self, shape: torch.Size, device: torch.device) -> torch.Tensor: + """Prior sampling is not necessary for the reverse SDE. + """ + raise NotImplementedError('Prior sampling not necessary for the reverse SDE.') + + def discretize( + self, + *, + state: torch.Tensor, + time: torch.Tensor, + score_condition: Optional[torch.Tensor] = None, + state_length: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Discretize the reverse SDE. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + score_condition: condition for the score estimator + state_length: length of the valid time steps for each example in the batch + **kwargs: other parameters for discretization of the forward SDE + """ + # Drift and diffusion from the forward SDE + forward_drift, forward_diffusion = self.forward_sde.discretize(state=state, time=time, **kwargs) + + # For input for the score estimator: + # - if no condition is provided, use the state + # - if a condition is provided, concatenate the state and the condition along the channel dimension + score_input = state if score_condition is None else torch.cat([state, score_condition], dim=1) + + # Estimate score + score, _ = self.score_estimator(input=score_input, input_length=state_length, condition=time) + + # Adjust drift + drift = forward_drift - forward_diffusion.pow(2) * score + + # Adjust diffusion + diffusion = forward_diffusion + + if state_length is not None: + drift = mask_sequence_tensor(drift, state_length) + diffusion = mask_sequence_tensor(diffusion, state_length) + + return drift, diffusion + + def copy(self): + return ReverseStochasticDifferentialEquation(sde=self.forward_sde.copy(), score_estimator=self.score_estimator) + + def __repr__(self): + desc = f'{self.__class__.__name__}(sde={self.forward_sde}, score_estimator={self.score_estimator})' + return desc + + +class SpectrogramNoiseConditionalScoreNetworkPlusPlus(NeuralModule): + """This model handles complex-valued inputs by stacking real and imaginary components. + Stacked tensor is processed using NCSN++ and the output is projected to generate real + and imaginary components of the output channels. + + Args: + in_channels: number of input complex-valued channels + out_channels: number of output complex-valued channels + """ + + def __init__(self, *, in_channels: int = 1, out_channels: int = 1, **kwargs): + super().__init__() + + # Number of input signals for this estimator + if in_channels < 1: + raise ValueError( + f'Number of input channels needs to be larger or equal to one, current value {in_channels}' + ) + + self.in_channels = in_channels + + # Number of output signals for this estimator + if out_channels < 1: + raise ValueError( + f'Number of output channels needs to be larger or equal to one, current value {out_channels}' + ) + + self.out_channels = out_channels + + # Instantiate noise conditional score network NCSN++ + ncsnpp_params = kwargs.copy() + ncsnpp_params['in_channels'] = ncsnpp_params['out_channels'] = 2 * self.in_channels # stack real and imag + self.ncsnpp = NoiseConditionalScoreNetworkPlusPlus(**ncsnpp_params) + + # Output projection to generate real and imaginary components of the output channels + self.output_projection = torch.nn.Conv2d( + in_channels=2 * self.in_channels, out_channels=2 * self.out_channels, kernel_size=1 + ) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tin_channels: %s', self.in_channels) + logging.debug('\tout_channels: %s', self.out_channels) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports. + """ + return { + "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + "condition": NeuralType(('B',), FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports. + """ + return { + "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "output_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @typecheck() + def forward(self, input, input_length=None, condition=None): + # Stack real and imaginary components + B, C_in, D, T = input.shape + + if C_in != self.in_channels: + raise RuntimeError(f'Unexpected input channel size {C_in}, expected {self.in_channels}') + + # Stack real and imaginary parts + input_real_imag = torch.stack([input.real, input.imag], dim=2) + input = einops.rearrange(input_real_imag, 'B C RI F T -> B (C RI) F T') + + # Process using NCSN++ + output, output_length = self.ncsnpp(input=input, input_length=input_length, condition=condition) + + # Output projection + output = self.output_projection(output) + + # Convert to complex-valued signal + output = output.reshape(B, 2, self.out_channels, D, T) + # Move real/imag dimension to the end + output = output.permute(0, 2, 3, 4, 1) + output = torch.view_as_complex(output.contiguous()) + + return output, output_length + + +class NoiseConditionalScoreNetworkPlusPlus(NeuralModule): + """Implementation of Noise Conditional Score Network (NCSN++) architecture. + + References: + - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021 + - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018 + """ + + def __init__( + self, + nonlinearity: str = "swish", + in_channels: int = 2, # number of channels in the input image + out_channels: int = 2, # number of channels in the output image + channels: Sequence[int] = (128, 128, 256, 256, 256), # number of channels at start + at every resolution + num_res_blocks: int = 2, + num_resolutions: int = 4, + init_scale: float = 1e-5, + conditioned_on_time: bool = False, + fourier_embedding_scale: float = 16.0, + dropout_rate: float = 0.0, + pad_time_to: Optional[int] = None, + pad_dimension_to: Optional[int] = None, + **_, + ): + # Network topology is a flavor of UNet, example chart for num_resolutions=4 + # + # 1: Image → Image/2 → Image/4 → Image/8 + # ↓ ↓ ↓ ↓ + # 2: Hidden → Hidden/2 → Hidden/4 → Hidden/8 + # ↓ ↓ ↓ ↓ + # 3: Hidden ← Hidden/2 ← Hidden/4 ← Hidden/8 + # ↓ ↓ ↓ ↓ + # 4: Image ← Image/2 ← Image/4 ← Image/8 + + # Horizontal arrows in (1) are downsampling + # Vertical arrows from (1) to (2) are channel upconversions + # + # Horizontal arrows in (2) are blocks with downsampling where necessary + # Horizontal arrows in (3) are blocks with upsampling where necessary + # + # Vertical arrows from (1) to (2) are downsampling and channel upconversioins + # Vertical arrows from (2) to (3) are sums connections (also with / sqrt(2)) + # Vertical arrows from (3) to (4) are channel downconversions + # Horizontal arrows in (4) are upsampling and addition + super().__init__() + + # same nonlinearity is used throughout the whole network + self.activation: torch.nn.Module = activation_registry[nonlinearity]() + self.init_scale: float = init_scale + + self.downsample = torch.nn.Upsample(scale_factor=0.5, mode="bilinear") + self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear") + + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_blocks = num_res_blocks + self.num_resolutions = num_resolutions + self.conditioned_on_time = conditioned_on_time + + # padding setup + self.pad_time_to = pad_time_to or 2 ** self.num_resolutions + self.pad_dimension_to = pad_dimension_to or 2 ** self.num_resolutions + + if self.conditioned_on_time: + self.time_embedding = torch.nn.Sequential( + GaussianFourierProjection(embedding_size=self.channels[0], scale=fourier_embedding_scale), + torch.nn.Linear(self.channels[0] * 2, self.channels[0] * 4), + self.activation, + torch.nn.Linear(self.channels[0] * 4, self.channels[0] * 4), + ) + + self.input_pyramid = torch.nn.ModuleList() + for ch in self.channels[:-1]: + self.input_pyramid.append(torch.nn.Conv2d(in_channels=self.in_channels, out_channels=ch, kernel_size=1)) + + # each block takes an image and outputs an image + # possibly changes number of channels + # output blocks ("reverse" path of the unet) reuse outputs of input blocks ("forward" path) + # so great care must be taken to in/out channels of each block + # resolutions are handled in `forward` + block_params = { + "activation": self.activation, + "dropout_rate": dropout_rate, + "init_scale": self.init_scale, + "diffusion_step_embedding_dim": channels[0] * 4 if self.conditioned_on_time else None, + } + self.input_blocks = torch.nn.ModuleList() + for in_ch, out_ch in zip(self.channels[:-1], self.channels[1:]): + for n in range(num_res_blocks): + block = ResnetBlockBigGANPlusPlus(in_ch=in_ch if n == 0 else out_ch, out_ch=out_ch, **block_params) + self.input_blocks.append(block) + + self.output_blocks = torch.nn.ModuleList() + for in_ch, out_ch in zip(reversed(self.channels[1:]), reversed(self.channels[:-1])): + for n in reversed(range(num_res_blocks)): + block = ResnetBlockBigGANPlusPlus(in_ch=in_ch, out_ch=out_ch if n == 0 else in_ch, **block_params) + self.output_blocks.append(block) + + self.projection_blocks = torch.nn.ModuleList() + for ch in self.channels[:-1]: + self.projection_blocks.append(torch.nn.Conv2d(ch, out_channels, kernel_size=1)) + + assert len(self.input_pyramid) == self.num_resolutions + assert len(self.input_blocks) == self.num_resolutions * self.num_res_blocks + assert len(self.output_blocks) == self.num_resolutions * self.num_res_blocks + assert len(self.projection_blocks) == self.num_resolutions + + self.init_weights_() + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tin_channels: %s', self.in_channels) + logging.debug('\tout_channels: %s', self.out_channels) + logging.debug('\tchannels: %s', self.channels) + logging.debug('\tnum_res_blocks: %s', self.num_res_blocks) + logging.debug('\tnum_resolutions: %s', self.num_resolutions) + logging.debug('\tconditioned_on_time: %s', self.conditioned_on_time) + logging.debug('\tpad_time_to: %s', self.pad_time_to) + logging.debug('\tpad_dimension_to: %s', self.pad_dimension_to) + + def init_weights_(self): + for module in self.modules(): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + # torch.nn submodules with scaled init + for module in self.projection_blocks: + torch.nn.init.xavier_uniform_(module.weight, gain=self.init_scale) + + # non-torch.nn submodules can have their own init schemes + for module in self.modules(): + if module is self: + continue + + if hasattr(module, "init_weights_"): + module.init_weights_() + + @typecheck( + input_types={"input": NeuralType(('B', 'C', 'D', 'T')),}, + output_types={"output": NeuralType(('B', 'C', 'D', 'T')),}, + ) + def pad_input(self, input: torch.Tensor) -> torch.Tensor: + """Pad input tensor to match the required dimensions across `T` and `D`. + """ + *_, D, T = input.shape + output = input + + # padding across time + if T % self.pad_time_to != 0: + output = F.pad(output, (0, self.pad_time_to - T % self.pad_time_to)) + + # padding across dimension + if D % self.pad_dimension_to != 0: + output = F.pad(output, (0, 0, 0, self.pad_dimension_to - D % self.pad_dimension_to)) + + return output + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports. + """ + return { + "input": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + "condition": NeuralType(('B',), FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports. + """ + return { + "output": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "output_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @typecheck() + def forward( + self, *, input: torch.Tensor, input_length: Optional[torch.Tensor], condition: Optional[torch.Tensor] = None + ): + """Forward pass of the model. + + Args: + input: input tensor, shjae (B, C, D, T) + input_length: length of the valid time steps for each example in the batch, shape (B,) + condition: scalar condition (time) for the model, will be embedded using `self.time_embedding` + """ + assert input.shape[1] == self.in_channels + + # apply padding at the input + *_, D, T = input.shape + input = self.pad_input(input=input) + + if input_length is None: + # assume all time frames are valid + input_length = torch.LongTensor([input.shape[-1]] * input.shape[0]).to(input.device) + + lengths = input_length + + if condition is not None: + if len(condition.shape) != 1: + raise ValueError( + f"Expected conditon to be a 1-dim tensor, got a {len(condition.shape)}-dim tensor of shape {tuple(condition.shape)}" + ) + if condition.shape[0] != input.shape[0]: + raise ValueError( + f"Condition {tuple(condition.shape)} and input {tuple(input.shape)} should match along the batch dimension" + ) + + condition = self.time_embedding(torch.log(condition)) + + # downsample and project input image to add later in the downsampling path + pyramid = [input] + for resolution_num in range(self.num_resolutions - 1): + pyramid.append(self.downsample(pyramid[-1])) + pyramid = [block(image) for image, block in zip(pyramid, self.input_pyramid)] + + # downsampling path + history = [] + hidden = torch.zeros_like(pyramid[0]) + input_blocks = iter(self.input_blocks) + for resolution_num, image in enumerate(pyramid): + hidden = (hidden + image) / math.sqrt(2.0) + hidden = mask_sequence_tensor(hidden, lengths) + + for _ in range(self.num_res_blocks): + hidden = next(input_blocks)(hidden, condition) + hidden = mask_sequence_tensor(hidden, lengths) + history.append(hidden) + + final_resolution = resolution_num == self.num_resolutions - 1 + if not final_resolution: + hidden = self.downsample(hidden) + lengths = (lengths / 2).ceil().long() + + # upsampling path + to_project = [] + for residual, block in zip(reversed(history), self.output_blocks): + if hidden.shape != residual.shape: + to_project.append(hidden) + hidden = self.upsample(hidden) + lengths = (lengths * 2).long() + + hidden = (hidden + residual) / math.sqrt(2.0) + hidden = block(hidden, condition) + hidden = mask_sequence_tensor(hidden, lengths) + + to_project.append(hidden) + + # projecting to images + images = [] + for tensor, projection in zip(to_project, reversed(self.projection_blocks)): + image = projection(tensor) + images.append(F.interpolate(image, size=input.shape[-2:])) # TODO write this loop using self.upsample + + result = sum(images) + + assert result.shape[-2:] == input.shape[-2:] + + # remove padding + result = result[:, :, :D, :T] + return result, input_length + + +class GaussianFourierProjection(NeuralModule): + """Gaussian Fourier embeddings for input scalars. + + The input scalars are typically time or noise levels. + """ + + def __init__(self, embedding_size: int = 256, scale: float = 1.0): + super().__init__() + self.W = torch.nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports. + """ + return { + "input": NeuralType(('B',), FloatType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports. + """ + return { + "output": NeuralType(('B', 'D'), VoidType()), + } + + def forward(self, input): + x_proj = input[:, None] * self.W[None, :] * 2 * math.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + + +class ResnetBlockBigGANPlusPlus(torch.nn.Module): + """Implementation of a ResNet block for the BigGAN model. + + References: + - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021 + - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018 + """ + + def __init__( + self, + activation: torch.nn.Module, + in_ch: int, + out_ch: int, + diffusion_step_embedding_dim: Optional[int] = None, + init_scale: float = 1e-5, + dropout_rate: float = 0.1, + in_num_groups: Optional[int] = None, + out_num_groups: Optional[int] = None, + eps: float = 1e-6, + ): + """ + Args: + activation (torch.nn.Module): activation layer (ReLU, SiLU, etc) + in_ch (int): number of channels in the input image + out_ch (int, optional): number of channels in the output image + diffusion_step_embedding_dim (int, optional): dimension of diffusion timestep embedding. Defaults to None (no embedding). + dropout_rate (float, optional): dropout rate. Defaults to 0.1. + init_scale (float, optional): scaling for weight initialization. Defaults to 0.0. + in_num_groups (int, optional): num_groups in the first GroupNorm. Defaults to min(in_ch // 4, 32) + out_num_groups (int, optional): num_groups in the second GroupNorm. Defaults to min(out_ch // 4, 32) + eps (float, optional): eps parameter of GroupNorms. Defaults to 1e-6. + """ + super().__init__() + in_num_groups = in_num_groups or min(in_ch // 4, 32) + out_num_groups = out_num_groups or min(out_ch // 4, 32) + + self.init_scale = init_scale + + self.input_block = torch.nn.Sequential( + torch.nn.GroupNorm(num_groups=in_num_groups, num_channels=in_ch, eps=eps), activation, + ) + + self.middle_conv = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1) + if diffusion_step_embedding_dim is not None: + self.diffusion_step_projection = torch.nn.Sequential( + activation, + torch.nn.Linear(diffusion_step_embedding_dim, out_ch), + einops.layers.torch.Rearrange("batch dim -> batch dim 1 1"), + ) + + self.output_block = torch.nn.Sequential( + torch.nn.GroupNorm(num_groups=out_num_groups, num_channels=out_ch, eps=eps), + activation, + torch.nn.Dropout(dropout_rate), + torch.nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1), + ) + + if in_ch != out_ch: + self.residual_projection = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1) + + self.act = activation + self.in_ch = in_ch + self.out_ch = out_ch + + self.init_weights_() + + def init_weights_(self): + """Weight initialization + """ + for module in self.modules(): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + # a single Conv2d is initialized with gain + torch.nn.init.xavier_uniform_(self.output_block[-1].weight, gain=self.init_scale) + + def forward(self, x: torch.Tensor, diffusion_time_embedding: Optional[torch.Tensor] = None): + """Forward pass of the model. + + Args: + x: input tensor + diffusion_time_embedding: embedding of the diffusion time step + + Returns: + Output tensor + """ + h = self.input_block(x) + h = self.middle_conv(h) + + if diffusion_time_embedding is not None: + h = h + self.diffusion_step_projection(diffusion_time_embedding) + + h = self.output_block(h) + + if x.shape != h.shape: # matching number of channels + x = self.residual_projection(x) + return (x + h) / math.sqrt(2.0) + + +class PredictorCorrectorSampler(NeuralModule): + """Predictor-Corrector sampler for the reverse SDE. + + Args: + sde: forward SDE + score_estimator: neural score estimator + predictor: predictor for the reverse process + corrector: corrector for the reverse process + num_steps: number of time steps for the reverse process + num_corrector_steps: number of corrector steps + time_max: maximum time + time_min: minimum time + snr: SNR for Annealed Langevin Dynamics + output_type: type of the output ('state' for the final state, or 'mean' for the mean of the final state) + + References: + - Song et al., Score-based generative modeling through stochastic differential equations, 2021 + """ + + def __init__( + self, + sde, + score_estimator, + predictor: str = 'reverse_diffusion', + corrector: str = 'annealed_langevin_dynamics', + num_steps: int = 50, + num_corrector_steps: int = 1, + time_max: Optional[float] = None, + time_min: Optional[float] = None, + snr: float = 0.5, + output_type: str = 'mean', + ): + super().__init__() + # Create a copy of SDE + self.sde = sde.copy() + + # Update SDE parameters for sampling + if time_max is not None: + self.sde.time_max = time_max + logging.info('sde.time_max set to: %s', self.sde.time_max) + + if time_min is not None: + self.sde.time_min = time_min + logging.info('sde.time_min set to: %s', self.sde.time_min) + + self.sde.num_steps = num_steps + logging.info('sde.num_steps set to: %s', self.sde.num_steps) + + # Update local values + self.time_max = self.sde.time_max + self.time_min = self.sde.time_min + self.num_steps = self.sde.num_steps + + # Predictor setup + if predictor == 'reverse_diffusion': + self.predictor = ReverseDiffusionPredictor(sde=self.sde, score_estimator=score_estimator) + else: + raise RuntimeError(f'Unexpected predictor: {predictor}') + + # Corrector setup + if corrector == 'annealed_langevin_dynamics': + self.corrector = AnnealedLangevinDynamics( + sde=self.sde, score_estimator=score_estimator, snr=snr, num_steps=num_corrector_steps + ) + else: + raise RuntimeError(f'Unexpected corrector: {corrector}') + + if output_type not in ['mean', 'state']: + raise ValueError(f'Unexpected output type: {output_type}') + self.output_type = output_type + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tpredictor: %s', predictor) + logging.debug('\tcorrector: %s', corrector) + logging.debug('\tnum_steps: %s', self.num_steps) + logging.debug('\ttime_min: %s', self.time_min) + logging.debug('\ttime_max: %s', self.time_max) + logging.debug('\tnum_corrector_steps: %s', num_corrector_steps) + logging.debug('\tsnr: %s', snr) + logging.debug('\toutput_type: %s', self.output_type) + + @typecheck( + input_types={ + "prior_mean": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "score_condition": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType(), optional=True), + "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + output_types={ + "sample": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + ) + @torch.inference_mode() + def forward( + self, prior_mean: torch.Tensor, score_condition: torch.Tensor, state_length: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Takes prior (noisy) mean and generates a sample by solving the reverse SDE. + + Args: + prior_mean: mean for the prior distribution, e.g., noisy observation + score_condition: conditioning for the score estimator + state_length: length of the valid time steps for each example in the batch + + Returns: + Generated `sample` and the corresponding `sample_length`. + """ + # Sample from the prior distribution + state = self.sde.prior_sampling(prior_mean=prior_mean) + + if state_length is not None: + state = mask_sequence_tensor(state, state_length) + + # Time steps for evaluation + time_steps = torch.linspace(self.time_max, self.time_min, self.num_steps, device=state.device) + + # Sampling + for t in time_steps: + # time steps for the whole batch + time = t * torch.ones(state.shape[0], device=state.device) + + # corrector step + state, _ = self.corrector( + state=state, time=time, score_condition=score_condition, state_length=state_length + ) + + # predictor step + state, state_mean = self.predictor( + state=state, + time=time, + score_condition=score_condition, + prior_mean=prior_mean, + state_length=state_length, + ) + + # Final output + if self.output_type == 'state': + sample = state + elif self.output_type == 'mean': + sample = state_mean + else: + raise RuntimeError(f'Unexpected output type: {self.output_type}') + + if state_length is not None: + sample = mask_sequence_tensor(sample, state_length) + + return sample, state_length + + +class Predictor(torch.nn.Module, ABC): + """Predictor for the reverse process. + + Args: + sde: forward SDE + score_estimator: neural score estimator + """ + + def __init__(self, sde, score_estimator): + super().__init__() + self.reverse_sde = ReverseStochasticDifferentialEquation(sde=sde, score_estimator=score_estimator) + + @abstractmethod + @torch.inference_mode() + def forward( + self, + *, + state: torch.Tensor, + time: torch.Tensor, + score_condition: Optional[torch.Tensor] = None, + state_length: Optional[torch.Tensor] = None, + **kwargs, + ): + """Predict the next state of the reverse process. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + score_condition: conditioning for the score estimator + state_length: length of the valid time steps for each example in the batch + + Returns: + New state and mean. + """ + pass + + +class ReverseDiffusionPredictor(Predictor): + """Predict the next state of the reverse process using the reverse diffusion process. + + Args: + sde: forward SDE + score_estimator: neural score estimator + """ + + def __init__(self, sde, score_estimator): + super().__init__(sde=sde, score_estimator=score_estimator) + + @torch.inference_mode() + def forward(self, *, state, time, score_condition=None, state_length=None, **kwargs): + """Predict the next state of the reverse process using the reverse diffusion process. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + score_condition: conditioning for the score estimator + state_length: length of the valid time steps for each example in the batch + + Returns: + New state and mean of the diffusion process. + """ + drift, diffusion = self.reverse_sde.discretize( + state=state, time=time, score_condition=score_condition, state_length=state_length, **kwargs + ) + + # Generate a random sample from a standard normal distribution + z_norm = torch.randn_like(state) + + # Compute the mean of the next state + mean = state - drift + + # Compute new state by sampling + new_state = mean + diffusion * z_norm + + if state_length is not None: + new_state = mask_sequence_tensor(new_state, state_length) + mean = mask_sequence_tensor(mean, state_length) + + return new_state, mean + + +class Corrector(NeuralModule, ABC): + """Corrector for the reverse process. + + Args: + sde: forward SDE + score_estimator: neural score estimator + snr: SNR for Annealed Langevin Dynamics + num_steps: number of steps for the corrector + """ + + def __init__( + self, + sde: Type[StochasticDifferentialEquation], + score_estimator: Type[NeuralModule], + snr: float, + num_steps: int, + ): + super().__init__() + self.sde = sde + self.score_estimator = score_estimator + self.snr = snr + self.num_steps = num_steps + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tsnr: %s', snr) + logging.debug('\tnum_steps: %s', num_steps) + + @abstractmethod + @typecheck( + input_types={ + "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "time": NeuralType(tuple('B'), FloatType()), + "score_condition": NeuralType(('B', 'C', 'D', 'T'), VoidType(), optional=True), + "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + output_types={"state": NeuralType(('B', 'C', 'D', 'T'), VoidType()),}, + ) + @torch.inference_mode() + def forward(self, state, time, score_condition=None, state_length=None): + """ + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + score_condition: conditioning for the score estimator + state_length: length of the valid time steps for each example in the batch + + Returns: + New state and mean. + """ + pass + + +class AnnealedLangevinDynamics(Corrector): + """Annealed Langevin Dynamics for the reverse process. + + References: + - Song et al., Score-based generative modeling through stochastic differential equations, 2021 + """ + + def __init__(self, sde, **kwargs): + if not isinstance(sde, OrnsteinUhlenbeckVarianceExplodingSDE): + raise ValueError(f'Expected an instance of OrnsteinUhlenbeckVarianceExplodingSDE, got {type(sde)}') + super().__init__(sde=sde, **kwargs) + + @torch.inference_mode() + def forward(self, state, time, score_condition=None, state_length=None): + """Correct the state using Annealed Langevin Dynamics. + + Args: + state: current state of the process, shape (B, C, D, T) + time: current time of the process, shape (B,) + score_condition: conditioning for the score estimator + state_length: length of the valid time steps for each example in the batch + + Returns: + New state and mean of the diffusion process. + + References: + Alg. 4 in http://arxiv.org/abs/2011.13456 + """ + # Compute the standard deviation of the diffusion process + std = self.sde.perturb_kernel_std(time=time) + # View as [B, 1, 1, 1] + std = std.view(-1, *([1] * (state.dim() - 1))) + + for i in range(self.num_steps): + # prepare input for the score estimator, concatenate conditioning along the channel dimension + score_input = state if score_condition is None else torch.cat([state, score_condition], dim=1) + + # calculate the score + score, _ = self.score_estimator(input=score_input, input_length=state_length, condition=time) + + # generate a sample from a standard normal distribution + z_norm = torch.randn_like(state) + + # compute the step size + # note: this is slightly different than in the paper, where std = ||z_norm||_2 / ||score||_2 + step_size = 2 * (self.snr * std).pow(2) + + # update the mean + mean = state + step_size * score + + # update the state + state = mean + z_norm * torch.sqrt(step_size * 2) + + if state_length is not None: + state = mask_sequence_tensor(state, state_length) + mean = mask_sequence_tensor(mean, state_length) + + return state, mean diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index b7863714eb2d..30e839fd2ca8 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -1,5 +1,6 @@ braceexpand editdistance +einops g2p_en ipywidgets jiwer diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py index 946acb614f11..a2e39628e4cb 100644 --- a/tests/collections/asr/test_asr_datasets.py +++ b/tests/collections/asr/test_asr_datasets.py @@ -809,6 +809,39 @@ def test_list_to_multichannel(self, num_channels, num_targets): # Check the list is converted back to the original signal assert (ASRAudioProcessor.list_to_multichannel(target_list) == golden_target).all() + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 2]) + def test_processor_process_audio(self, num_channels): + """Test signal normalization in process_audio. + """ + num_samples = 1000 + num_examples = 30 + + signals = ['input_signal', 'target_signal', 'reference_signal'] + + for normalization_signal in [None] + signals: + # Create processor + processor = ASRAudioProcessor( + sample_rate=16000, random_offset=False, normalization_signal=normalization_signal + ) + + # Generate random signals + for n in range(num_examples): + example = {signal: torch.randn(num_channels, num_samples) for signal in signals} + processed_example = processor.process_audio(example) + + # Expected scale + if normalization_signal: + scale = 1.0 / (example[normalization_signal].abs().max() + processor.eps) + else: + scale = 1.0 + + # Make sure all signals are scaled as expected + for signal in signals: + assert torch.allclose( + processed_example[signal], example[signal] * scale + ), f'Failed example {n} signal {signal}' + @pytest.mark.unit def test_audio_collate_fn(self): """Test `_audio_collate_fn` diff --git a/tests/collections/asr/test_asr_losses.py b/tests/collections/asr/test_asr_losses.py index e09fd71e0892..e050e7cc07c3 100644 --- a/tests/collections/asr/test_asr_losses.py +++ b/tests/collections/asr/test_asr_losses.py @@ -17,7 +17,9 @@ import torch from nemo.collections.asr.losses.audio_losses import ( + MSELoss, SDRLoss, + calculate_mse_batch, calculate_sdr_batch, convolution_invariant_target, scale_invariant_target, @@ -271,7 +273,7 @@ def test_sdr_binary_mask(self, num_channels): estimate = target + noise # Limit calculation to masked samples - mask = _rng.integers(low=0, high=2, size=(batch_size, max_num_samples)) + mask = _rng.integers(low=0, high=2, size=(batch_size, num_channels, max_num_samples)) # Tensors for testing the loss tensor_estimate = torch.tensor(estimate) @@ -282,7 +284,9 @@ def test_sdr_binary_mask(self, num_channels): golden_sdr = 0 for b in range(batch_size): sdr = [ - calculate_sdr_numpy(estimate=estimate[b, m, mask[b, :] > 0], target=target[b, m, mask[b, :] > 0]) + calculate_sdr_numpy( + estimate=estimate[b, m, mask[b, m, :] > 0], target=target[b, m, mask[b, m, :] > 0] + ) for m in range(num_channels) ] sdr = np.mean(np.array(sdr)) @@ -467,3 +471,187 @@ def test_sdr_convolution_invariant(self, num_channels: int, filter_length: int): assert np.allclose( uut_sdr_loss.cpu().detach().numpy(), -golden_sdr, atol=atol ), f'SDRLoss not matching for example {n}' + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 4]) + @pytest.mark.parametrize('ndim', [3, 4]) + def test_mse(self, num_channels: int, ndim: int): + """Test SDR calculation + """ + batch_size = 8 + num_samples = 50 + num_features = 123 + num_batches = 10 + random_seed = 42 + atol = 1e-6 + + signal_shape = ( + (batch_size, num_channels, num_features, num_samples) + if ndim == 4 + else (batch_size, num_channels, num_samples) + ) + + reduction_dim = (-2, -1) if ndim == 4 else -1 + + mse_loss = MSELoss(ndim=ndim) + + _rng = np.random.default_rng(seed=random_seed) + + for n in range(num_batches): + + # Generate random signal + target = _rng.normal(size=signal_shape) + # Random noise + scaling + noise = _rng.uniform(low=0.01, high=1) * _rng.normal(size=signal_shape) + # Estimate + estimate = target + noise + + # DC bias for both + target += _rng.uniform(low=-1, high=1) + estimate += _rng.uniform(low=-1, high=1) + + # Tensors for testing the loss + tensor_estimate = torch.tensor(estimate) + tensor_target = torch.tensor(target) + + # Reference MSE + golden_mse = np.zeros((batch_size, num_channels)) + for b in range(batch_size): + for m in range(num_channels): + err = estimate[b, m, :] - target[b, m, :] + golden_mse[b, m] = np.mean(np.abs(err) ** 2, axis=reduction_dim) + + # Calculate MSE in torch + uut_mse = calculate_mse_batch(estimate=tensor_estimate, target=tensor_target) + + # Calculate MSE loss + uut_mse_loss = mse_loss(estimate=tensor_estimate, target=tensor_target) + + # Compare torch SDR vs numpy + assert np.allclose( + uut_mse.cpu().detach().numpy(), golden_mse, atol=atol + ), f'MSE not matching for example {n}' + + # Compare SDR loss vs average of torch SDR + assert np.isclose(uut_mse_loss, uut_mse.mean(), atol=atol), f'MSELoss not matching for example {n}' + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 4]) + @pytest.mark.parametrize('ndim', [3, 4]) + def test_mse_weighted(self, num_channels: int, ndim: int): + """Test SDR calculation with weighting for channels + """ + batch_size = 8 + num_samples = 50 + num_features = 123 + num_batches = 10 + random_seed = 42 + atol = 1e-6 + + signal_shape = ( + (batch_size, num_channels, num_features, num_samples) + if ndim == 4 + else (batch_size, num_channels, num_samples) + ) + + reduction_dim = (-2, -1) if ndim == 4 else -1 + + _rng = np.random.default_rng(seed=random_seed) + + channel_weight = _rng.uniform(low=0.01, high=1.0, size=num_channels) + channel_weight = channel_weight / np.sum(channel_weight) + mse_loss = MSELoss(weight=channel_weight, ndim=ndim) + + for n in range(num_batches): + + # Generate random signal + target = _rng.normal(size=signal_shape) + # Random noise + scaling + noise = _rng.uniform(low=0.001, high=10) * _rng.normal(size=target.shape) + # Estimate + estimate = target + noise + + # Tensors for testing the loss + tensor_estimate = torch.tensor(estimate) + tensor_target = torch.tensor(target) + + # Reference MSE + golden_mse = 0 + for b in range(batch_size): + mse = [ + np.mean(np.abs(estimate[b, m, :] - target[b, m, :]) ** 2, axis=reduction_dim) + for m in range(num_channels) + ] + # weighted sum + mse = np.sum(np.array(mse) * channel_weight) + golden_mse += mse + golden_mse /= batch_size # average over batch + + # Calculate MSE loss + uut_mse_loss = mse_loss(estimate=tensor_estimate, target=tensor_target) + + # Compare + assert np.allclose( + uut_mse_loss.cpu().detach().numpy(), golden_mse, atol=atol + ), f'MSELoss not matching for example {n}' + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 4]) + @pytest.mark.parametrize('ndim', [3, 4]) + def test_mse_input_length(self, num_channels: int, ndim: int): + """Test SDR calculation with input length. + """ + batch_size = 8 + max_num_samples = 50 + num_features = 123 + num_batches = 10 + random_seed = 42 + atol = 1e-6 + + signal_shape = ( + (batch_size, num_channels, num_features, max_num_samples) + if ndim == 4 + else (batch_size, num_channels, max_num_samples) + ) + + reduction_dim = (-2, -1) if ndim == 4 else -1 + + _rng = np.random.default_rng(seed=random_seed) + + mse_loss = MSELoss(ndim=ndim) + + for n in range(num_batches): + + # Generate random signal + target = _rng.normal(size=signal_shape) + # Random noise + scaling + noise = _rng.uniform(low=0.001, high=10) * _rng.normal(size=target.shape) + # Estimate + estimate = target + noise + + # Limit calculation to random input_length samples + input_length = _rng.integers(low=1, high=max_num_samples, size=batch_size) + + # Tensors for testing the loss + tensor_estimate = torch.tensor(estimate) + tensor_target = torch.tensor(target) + tensor_input_length = torch.tensor(input_length) + + # Reference MSE + golden_mse = 0 + for b, b_len in enumerate(input_length): + mse = [ + np.mean(np.abs(estimate[b, m, ..., :b_len] - target[b, m, ..., :b_len]) ** 2, axis=reduction_dim) + for m in range(num_channels) + ] + mse = np.mean(np.array(mse)) + golden_mse += mse + golden_mse /= batch_size # average over batch + + # Calculate MSE + uut_mse_loss = mse_loss(estimate=tensor_estimate, target=tensor_target, input_length=tensor_input_length) + + # Compare + assert np.allclose( + uut_mse_loss.cpu().detach().numpy(), golden_mse, atol=atol + ), f'MSELoss not matching for example {n}' diff --git a/tests/collections/asr/test_audio_preprocessing.py b/tests/collections/asr/test_audio_preprocessing.py index b0875936a7f7..600b9fed44fa 100644 --- a/tests/collections/asr/test_audio_preprocessing.py +++ b/tests/collections/asr/test_audio_preprocessing.py @@ -155,7 +155,11 @@ def test_spec_to_audio(self, fft_length: int, num_channels: int): @pytest.mark.skipif(not HAVE_TORCHAUDIO, reason="Modules in this test require torchaudio") @pytest.mark.parametrize('fft_length', [128, 1024]) @pytest.mark.parametrize('num_channels', [1, 4]) - def test_audio_to_spectrogram_reconstruction(self, fft_length: int, num_channels: int): + @pytest.mark.parametrize('magnitude_power', [0.5, 1, 2]) + @pytest.mark.parametrize('scale', [0.1, 1.0]) + def test_audio_to_spectrogram_reconstruction( + self, fft_length: int, num_channels: int, magnitude_power: float, scale: float + ): """Test analysis and synthesis transform result in a perfect reconstruction. """ batch_size = 4 @@ -169,8 +173,12 @@ def test_audio_to_spectrogram_reconstruction(self, fft_length: int, num_channels hop_lengths = [fft_length // 2, fft_length // 4] for hop_length in hop_lengths: - audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length) - spec2audio = SpectrogramToAudio(fft_length=fft_length, hop_length=hop_length) + audio2spec = AudioToSpectrogram( + fft_length=fft_length, hop_length=hop_length, magnitude_power=magnitude_power, scale=scale + ) + spec2audio = SpectrogramToAudio( + fft_length=fft_length, hop_length=hop_length, magnitude_power=magnitude_power, scale=scale + ) for n in range(num_examples): x = _rng.normal(size=(batch_size, num_channels, num_samples)) From 8c1ce65961c60df8c58817cae6f1cb7b5e5d407a Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> Date: Wed, 1 May 2024 15:52:36 -0700 Subject: [PATCH 05/73] Fix docs errors and most warnings (#9006) * add various docs fixes Signed-off-by: Elena Rastorgueva * make conf.py changes clearer Signed-off-by: Elena Rastorgueva * fix Duplicate explicit target name error for links Signed-off-by: Elena Rastorgueva * more fixes, mainly citations Signed-off-by: Elena Rastorgueva * fix some code formatting Signed-off-by: Elena Rastorgueva * update hf space iframe link Signed-off-by: Elena Rastorgueva * fix new ERRORs Signed-off-by: Elena Rastorgueva * Update docs Signed-off-by: yaoyu-33 --------- Signed-off-by: Elena Rastorgueva Signed-off-by: yaoyu-33 Co-authored-by: yaoyu-33 Co-authored-by: Eric Harper --- docs/source/asr/datasets.rst | 53 ++++++------ docs/source/asr/intro.rst | 4 +- docs/source/asr/models.rst | 4 +- docs/source/asr/speech_intent_slot/api.rst | 2 + docs/source/asr/ssl/api.rst | 2 + docs/source/ckpt_converters/dev_guide.rst | 4 +- docs/source/ckpt_converters/user_guide.rst | 84 +++++++++---------- docs/source/conf.py | 3 +- docs/source/core/adapters/api.rst | 7 ++ docs/source/core/adapters/components.rst | 12 ++- docs/source/core/adapters/intro.rst | 1 + docs/source/core/core.rst | 11 +-- docs/source/core/exp_manager.rst | 1 + docs/source/core/export.rst | 3 +- docs/source/core/neural_types.rst | 3 + docs/source/features/memory_optimizations.rst | 13 +-- docs/source/multimodal/api.rst | 9 +- docs/source/multimodal/mllm/checkpoint.rst | 10 +-- docs/source/multimodal/nerf/dreamfusion.rst | 6 +- .../source/multimodal/text2img/controlnet.rst | 8 +- .../source/multimodal/text2img/dreambooth.rst | 8 +- docs/source/multimodal/text2img/imagen.rst | 10 +-- docs/source/multimodal/text2img/insp2p.rst | 6 +- docs/source/multimodal/text2img/intro.rst | 1 + .../multimodal/text2img/sdxl_quantization.rst | 10 ++- docs/source/multimodal/vlm/clip.rst | 6 +- docs/source/nlp/api.rst | 19 ++--- docs/source/nlp/information_retrieval.rst | 2 +- .../machine_translation.rst | 8 +- .../nlp/nemo_megatron/gpt/gpt_training.rst | 2 +- .../nemo_megatron/positional_embeddings.rst | 28 +++---- ...ation_and_capitalization_lexical_audio.rst | 6 +- .../text_normalization_as_tagging.rst | 8 +- docs/source/starthere/best-practices.rst | 2 +- docs/source/starthere/migration-guide.rst | 20 ++--- docs/source/tools/nemo_forced_aligner.rst | 8 +- docs/source/vision/checkpoint.rst | 2 +- docs/source/vision/vit.rst | 6 +- nemo/collections/asr/models/asr_model.py | 4 +- nemo/collections/asr/models/msdd_models.py | 13 ++- nemo/collections/asr/modules/rnnt.py | 23 +++-- .../tokenizers/huggingface/auto_tokenizer.py | 11 ++- .../language_modeling/megatron/t5_dataset.py | 3 +- .../megatron/t5_prompt_learning_dataset.py | 4 +- .../language_modeling/megatron/ul2_dataset.py | 4 +- .../megatron_bert_embedding_model.py | 8 +- .../language_modeling/megatron_bert_model.py | 8 +- .../language_modeling/megatron_gpt_model.py | 8 +- .../megatron_lm_encoder_decoder_model.py | 12 ++- .../common/transformer/text_generation.py | 57 ++++++------- .../megatron_vit_classification_models.py | 8 +- nemo/core/classes/dataset.py | 15 ++-- nemo/utils/exp_manager.py | 4 +- 53 files changed, 306 insertions(+), 268 deletions(-) diff --git a/docs/source/asr/datasets.rst b/docs/source/asr/datasets.rst index b4656eec3f3f..a6e9cbe96c63 100644 --- a/docs/source/asr/datasets.rst +++ b/docs/source/asr/datasets.rst @@ -261,11 +261,6 @@ Semi Sorted Batching Sorting samples by duration and spliting them into batches speeds up training, but can degrade the quality of the model. To avoid quality degradation and maintain some randomness in the partitioning process, we add pseudo noise to the sample length when sorting. - .. image:: images/ssb.png - :align: center - :alt: semi sorted batching - :scale: 50% - It may result into training speeedup of more than 40 percent with the same quality. To enable and use semi sorted batching add some lines in config. .. code:: @@ -772,30 +767,30 @@ To enable multimodal dataloading, we provide several configuration options: Example 3. Combine an ASR (audio-text) dataset with an MT (text-only) dataset so that mini-batches have some examples from both datasets. Provide a custom prompt field for both datasets (to be leveraged by a relevant dataset class): -```yaml -use_multimodal_sampling: true -batch_tokens: 1024 -token_equivalent_duration: 0.08 # 0.01 frame shift * 8 subsampling factor -quadratic_factor: 50 -num_buckets: 30 -use_bucketing: true -input_cfg: - - type: nemo_tarred - manifest_filepath: /path/to/manifest__OP_0..512_CL_.json - tarred_audio_filepath: /path/to/tarred_audio/audio__OP_0..512_CL_.tar - weight: 0.5 - tags: - lang: en - prompt: "Given the following recording, transcribe what the person is saying:" - - type: txt_pair - source_path: /path/to/en__OP_0..512_CL_.txt - target_path: /path/to/pl__OP_0..512_CL_.txt - source_language: en - target_language: pl - weight: 0.5 - tags: - prompt: "Translate the following text to Polish:" -``` +.. code-block:: yaml + + use_multimodal_sampling: true + batch_tokens: 1024 + token_equivalent_duration: 0.08 # 0.01 frame shift * 8 subsampling factor + quadratic_factor: 50 + num_buckets: 30 + use_bucketing: true + input_cfg: + - type: nemo_tarred + manifest_filepath: /path/to/manifest__OP_0..512_CL_.json + tarred_audio_filepath: /path/to/tarred_audio/audio__OP_0..512_CL_.tar + weight: 0.5 + tags: + lang: en + prompt: "Given the following recording, transcribe what the person is saying:" + - type: txt_pair + source_path: /path/to/en__OP_0..512_CL_.txt + target_path: /path/to/pl__OP_0..512_CL_.txt + source_language: en + target_language: pl + weight: 0.5 + tags: + prompt: "Translate the following text to Polish:" .. caution:: We strongly recommend to use multiple shards for text files as well so that different nodes and dataloading workers are able to randomize the order of text iteration. Otherwise, multi-GPU training has a high risk of duplication of text examples. diff --git a/docs/source/asr/intro.rst b/docs/source/asr/intro.rst index 7d1270af1267..d353b4d983dd 100644 --- a/docs/source/asr/intro.rst +++ b/docs/source/asr/intro.rst @@ -156,11 +156,11 @@ Canary-1B is a multi-lingual, multi-task model, supporting automatic speech-to-t .. raw:: html - diff --git a/docs/source/asr/models.rst b/docs/source/asr/models.rst index 97dafcb2bf6d..f002137beb0f 100644 --- a/docs/source/asr/models.rst +++ b/docs/source/asr/models.rst @@ -46,12 +46,14 @@ HuggingFace Spaces to try out Parakeet models in your browser: * `Parakeet-TDT-1.1B `__ space .. _Conformer_model: + Conformer --------- + .. _Conformer-CTC_model: + Conformer-CTC ~~~~~~~~~~~~~ -------------- Conformer-CTC is a CTC-based variant of the Conformer model introduced in :cite:`asr-models-gulati2020conformer`. Conformer-CTC has a similar encoder as the original Conformer but uses CTC loss and decoding instead of RNNT/Transducer loss, which makes it a non-autoregressive model. diff --git a/docs/source/asr/speech_intent_slot/api.rst b/docs/source/asr/speech_intent_slot/api.rst index 735c583f9115..d45f24f807f6 100644 --- a/docs/source/asr/speech_intent_slot/api.rst +++ b/docs/source/asr/speech_intent_slot/api.rst @@ -15,8 +15,10 @@ Mixins .. autoclass:: nemo.collections.asr.parts.mixins.ASRModuleMixin :show-inheritance: :members: + :no-index: .. autoclass:: nemo.collections.asr.parts.mixins.ASRBPEMixin :show-inheritance: :members: + :no-index: diff --git a/docs/source/asr/ssl/api.rst b/docs/source/asr/ssl/api.rst index 7103243a4b20..8e6f83986032 100644 --- a/docs/source/asr/ssl/api.rst +++ b/docs/source/asr/ssl/api.rst @@ -15,10 +15,12 @@ Mixins .. autoclass:: nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin :show-inheritance: :members: + :no-index: .. autoclass:: nemo.core.classes.mixins.access_mixins.AccessMixin :show-inheritance: :members: + :no-index: diff --git a/docs/source/ckpt_converters/dev_guide.rst b/docs/source/ckpt_converters/dev_guide.rst index 9faa752df2e1..601e69749b64 100644 --- a/docs/source/ckpt_converters/dev_guide.rst +++ b/docs/source/ckpt_converters/dev_guide.rst @@ -48,7 +48,7 @@ Script Placement and Naming Conventions Code Template ------------- -Below template tries to address the 11 steps in the guideline part. Please also use `Gemma Huggingface to NeMo converter `_ as an full example for development. +Below template tries to address the 11 steps in the guideline part. Please also use `Gemma Huggingface to NeMo converter `__ as an full example for development. .. code-block:: python @@ -210,7 +210,7 @@ A Simple Guide for Model Mapping and Conversion 2. **Common issues when converting: results not matching between Community model and NeMo model**: - a. Megatron Core uses a special QKV layout, which needs careful handling and reshaping from community models, especially when GQA or MQA is used. Refer to the `Gemma Huggingface to NeMo converter `_ for guidance. + a. Megatron Core uses a special QKV layout, which needs careful handling and reshaping from community models, especially when GQA or MQA is used. Refer to the `Gemma Huggingface to NeMo converter `__ for guidance. b. GLU Variants weights could also be a common source of error. In Megatron Core, the regular feedforward projection weights and gated forward weights are fused together, requiring careful attention to the order of these two. Refer to the `Gemma Huggingface to NeMo converter `_ for more details. diff --git a/docs/source/ckpt_converters/user_guide.rst b/docs/source/ckpt_converters/user_guide.rst index 9de22f4b5994..451679a7e3ae 100644 --- a/docs/source/ckpt_converters/user_guide.rst +++ b/docs/source/ckpt_converters/user_guide.rst @@ -6,45 +6,45 @@ This guide provides instructions on how to use the conversion scripts to convert Support Matrix -------------- -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Conversion | From | To | Github Link | -+======================+==================+=====================+====================================================================================================================+ -| Baichuan | Hugging Face | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Baichuan | NeMo | Hugging Face | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| BERT | Hugging Face | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| BERT | NeMo | Hugging Face | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Falcon | Hugging Face | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Falcon | NeMo | Hugging Face | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Gemma | Hugging Face | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Gemma | JAX | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Gemma | PyTorch | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| GPT/LLaMA | NeMo (Legacy) | NeMo (Megatron-Core)| `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| LLaMA | Hugging Face | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| LLaMA | NeMo | Hugging Face | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Mistral 7B | Hugging Face | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Mistral 7B | NeMo | Hugging Face | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Mixtral | Hugging Face | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Mixtral | NeMo | Hugging Face | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| MPT | Hugging Face | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ -| Starcoder | Hugging Face | NeMo | `Link `_ | -+----------------------+------------------+---------------------+--------------------------------------------------------------------------------------------------------------------+ ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Conversion | From | To | Github Link | ++======================+==================+=====================+=====================================================================================================================+ +| Baichuan | Hugging Face | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Baichuan | NeMo | Hugging Face | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| BERT | Hugging Face | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| BERT | NeMo | Hugging Face | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Falcon | Hugging Face | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Falcon | NeMo | Hugging Face | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Gemma | Hugging Face | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Gemma | JAX | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Gemma | PyTorch | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| GPT/LLaMA | NeMo (Legacy) | NeMo (Megatron-Core)| `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| LLaMA | Hugging Face | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| LLaMA | NeMo | Hugging Face | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Mistral 7B | Hugging Face | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Mistral 7B | NeMo | Hugging Face | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Mixtral | Hugging Face | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Mixtral | NeMo | Hugging Face | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| MPT | Hugging Face | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ +| Starcoder | Hugging Face | NeMo | `Link `__ | ++----------------------+------------------+---------------------+---------------------------------------------------------------------------------------------------------------------+ Convert Hugging Face LLaMA Checkpoints to NeMo @@ -54,7 +54,7 @@ To convert a Hugging Face LLaMA checkpoint into a NeMo checkpoint, use the follo .. code-block:: bash - python convert_llama_hf_to_nemo.py>`_ \ + python convert_llama_hf_to_nemo.py \ --input_name_or_path \ --output_path @@ -67,7 +67,7 @@ To convert a NeMo checkpoint into a Hugging Face LLaMA checkpoint, you have two .. code-block:: bash - python convert__nemo_to_hf.py>`_ \ + python convert__nemo_to_hf.py \ --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \ --output_path /path/to/pytorch_model.bin @@ -75,7 +75,7 @@ To convert a NeMo checkpoint into a Hugging Face LLaMA checkpoint, you have two .. code-block:: bash - python convert__nemo_to_hf.py>`_ \ + python convert__nemo_to_hf.py \ --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \ --output_path /path/to/model_folder \ --hf_input_path /path/to/input_hf_folder \ diff --git a/docs/source/conf.py b/docs/source/conf.py index e8fba7457605..c599f630d7f7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -113,10 +113,9 @@ "sphinx.ext.viewcode", "sphinx.ext.napoleon", "sphinx.ext.githubpages", - "sphinxcontrib.bibtex", "sphinx.ext.inheritance_diagram", "sphinx.ext.intersphinx", - "sphinx.ext.autosectionlabel", + # "sphinx.ext.autosectionlabel", "sphinxcontrib.bibtex", "sphinx_copybutton", "sphinxext.opengraph", diff --git a/docs/source/core/adapters/api.rst b/docs/source/core/adapters/api.rst index b0f2a8e13610..8922c72d63eb 100644 --- a/docs/source/core/adapters/api.rst +++ b/docs/source/core/adapters/api.rst @@ -9,6 +9,7 @@ Core :members: :member-order: bysource :undoc-members: adapter_module_names + :no-index: ----- @@ -17,6 +18,7 @@ Core :members: :member-order: bysource :undoc-members: adapter_module_names + :no-index: ----- @@ -28,6 +30,7 @@ Adapter Networks :show-inheritance: :members: :member-order: bysource + :no-index: ----- @@ -35,6 +38,7 @@ Adapter Networks :show-inheritance: :members: :member-order: bysource + :no-index: ----- @@ -47,6 +51,7 @@ Adapter Strategies :members: :member-order: bysource :undoc-members: adapter_module_names + :no-index: ----- @@ -55,6 +60,7 @@ Adapter Strategies :members: :member-order: bysource :undoc-members: adapter_module_names + :no-index: ----- @@ -63,3 +69,4 @@ Adapter Strategies :members: :member-order: bysource :undoc-members: adapter_module_names + :no-index: diff --git a/docs/source/core/adapters/components.rst b/docs/source/core/adapters/components.rst index cc2ea0b525df..d8bed1b23a75 100644 --- a/docs/source/core/adapters/components.rst +++ b/docs/source/core/adapters/components.rst @@ -8,7 +8,7 @@ An adapter module can be any pytorch module, but it must follow certain straight 1) The model accepts an input of some input dimension, and its output must match this dimension. 2) Ideally, the module is initialized such that the output of the adapter when initialized is such that it does not modify the original input. This allows the model to produce the same output results, even when additional parameters have been added. -According to Junxian et al :cite:`adapters-Junxian2021unified`, we can consider an adapter being represented as three components - +According to Junxian et al :cite:`adapters-components-Junxian2021unified`, we can consider an adapter being represented as three components - 1) Functional form - the trainable parameters that will modify the input 2) Insertion form - Where the adapter outputs are integrated with the original input. The input to the adapters can be the last output of the layer, the input to some attention layer, or even the original input to the module itself (before even the modules forward pass). @@ -17,7 +17,7 @@ According to Junxian et al :cite:`adapters-Junxian2021unified`, we can consider Functional Form - Adapter Networks ================================== -Adapter modules represent the functional form of the adapter. We discuss an example of a most commonly used adapter module found in literature, titled the ``LinearAdapter`` (or Houlsby Adapter) :cite:`adapters-houlsby2019adapter`. +Adapter modules represent the functional form of the adapter. We discuss an example of a most commonly used adapter module found in literature, titled the ``LinearAdapter`` (or Houlsby Adapter) :cite:`adapters-components-houlsby2019adapter`. .. note:: @@ -28,6 +28,7 @@ Adapter modules represent the functional form of the adapter. We discuss an exam :show-inheritance: :members: :member-order: bysource + :no-index: ----- @@ -35,12 +36,13 @@ Adapter modules represent the functional form of the adapter. We discuss an exam :show-inheritance: :members: :member-order: bysource + :no-index: Insertion Form - Module Adapters -------------------------------- -Adapter modules can be integrated into many different locations of a given module. For example, it is possible to have an adapter that affects only the outputs of the final layer in each module. We can also have a ``Parallel Adapter`` :cite:`adapters-Junxian2021unified` that operates at the input of the module itself, in parallel to the forward pass of the module. Yet another insertion location is inside the Multi Head Attention Layers. +Adapter modules can be integrated into many different locations of a given module. For example, it is possible to have an adapter that affects only the outputs of the final layer in each module. We can also have a ``Parallel Adapter`` :cite:`adapters-components-Junxian2021unified` that operates at the input of the module itself, in parallel to the forward pass of the module. Yet another insertion location is inside the Multi Head Attention Layers. On top of this, while adapters are commonly used only in the layers containing the most parameters (say the Encoder of a network), some models can support adapters in multiple locations (Encoder-Decoder architecture for Language Models, Machine Translation, or even Encoder-Decoder-Joint for ASR with Transducer Loss). As such, NeMo utilizes the concept of ``Module Adapters``. @@ -70,6 +72,7 @@ We discuss a simple residual additional connection strategy below - that accepts :members: :member-order: bysource :undoc-members: adapter_module_names + :no-index: ----- @@ -78,6 +81,7 @@ We discuss a simple residual additional connection strategy below - that accepts :members: :member-order: bysource :undoc-members: adapter_module_names + :no-index: ----- @@ -87,4 +91,4 @@ References .. bibliography:: ./adapter_bib.bib :style: plain - :keyprefix: adapters- + :keyprefix: adapters-components- diff --git a/docs/source/core/adapters/intro.rst b/docs/source/core/adapters/intro.rst index fd94c8d23446..8c5e9cbc8895 100644 --- a/docs/source/core/adapters/intro.rst +++ b/docs/source/core/adapters/intro.rst @@ -144,4 +144,5 @@ References .. bibliography:: ./adapter_bib.bib :style: plain + :labelprefix: adapters :keyprefix: adapters- diff --git a/docs/source/core/core.rst b/docs/source/core/core.rst index 6e5efa56d5f0..1c9325cf0a96 100644 --- a/docs/source/core/core.rst +++ b/docs/source/core/core.rst @@ -16,9 +16,10 @@ NeMo models contain everything needed to train and reproduce Conversational AI m NeMo uses `Hydra `_ for configuring both NeMo models and the PyTorch Lightning Trainer. -.. note:: Every NeMo model has an example configuration file and training script that can be found `here `_. +.. note:: + Every NeMo model has an example configuration file and training script that can be found `here `__. -The end result of using NeMo, `Pytorch Lightning `_, and Hydra is that NeMo models all have the same look and feel and are also fully compatible with the PyTorch ecosystem. +The end result of using NeMo, `Pytorch Lightning `__, and Hydra is that NeMo models all have the same look and feel and are also fully compatible with the PyTorch ecosystem. Pretrained ---------- @@ -42,14 +43,14 @@ To see all available pretrained models for a specific NeMo model, use the ``list For detailed information on the available pretrained models, refer to the collections documentation: -- :ref:`Automatic Speech Recognition (ASR)` +- :doc:`Automatic Speech Recognition (ASR) <../asr/intro>` - :doc:`Natural Language Processing (NLP) <../nlp/models>` - :doc:`Text-to-Speech Synthesis (TTS) <../tts/intro>` Training -------- -NeMo leverages `PyTorch Lightning `_ for model training. PyTorch Lightning lets NeMo decouple the +NeMo leverages `PyTorch Lightning `__ for model training. PyTorch Lightning lets NeMo decouple the conversational AI code from the PyTorch training code. This means that NeMo users can focus on their domain (ASR, NLP, TTS) and build complex AI applications without having to rewrite boiler plate code for PyTorch training. @@ -298,7 +299,7 @@ With NeMo and Hydra, every aspect of model training can be modified from the com of experiments on compute clusters or for quickly testing parameters while developing. All NeMo `examples `_ come with instructions on how to -run the training/inference script from the command-line (see `here `_ +run the training/inference script from the command-line (see `here `__ for an example). With Hydra, arguments are set using the ``=`` operator: diff --git a/docs/source/core/exp_manager.rst b/docs/source/core/exp_manager.rst index b44d27c38b4b..efb55b0feabb 100644 --- a/docs/source/core/exp_manager.rst +++ b/docs/source/core/exp_manager.rst @@ -379,3 +379,4 @@ ExpManagerConfig :show-inheritance: :members: :member-order: bysource + :no-index: diff --git a/docs/source/core/export.rst b/docs/source/core/export.rst index 990769452a5c..c53dd8159a60 100644 --- a/docs/source/core/export.rst +++ b/docs/source/core/export.rst @@ -194,7 +194,7 @@ To facilitate that, the hooks below are provided. To export, for example, 'encod First goes the one receiving input (input_example) """ -Some nertworks may be exported differently according to user-settable options (like ragged batch support for TTS or cache support for ASR). To facilitate that - `set_export_config()` method is provided by Exportable to set key/value pairs to predefined model.export_config dictionary, to be used during the export: +Some networks may be exported differently according to user-settable options (like ragged batch support for TTS or cache support for ASR). To facilitate that - `set_export_config()` method is provided by Exportable to set key/value pairs to predefined model.export_config dictionary, to be used during the export: .. code-block:: Python @@ -202,6 +202,7 @@ Some nertworks may be exported differently according to user-settable options (l """ Sets/updates export_config dictionary """ + Also, if an action hook on setting config is desired, this method may be overloaded by `Exportable` descendants to include one. An example can be found in ``/nemo/collections/asr/models/rnnt_models.py``. diff --git a/docs/source/core/neural_types.rst b/docs/source/core/neural_types.rst index 9003b9ca5203..ec7d94336c05 100644 --- a/docs/source/core/neural_types.rst +++ b/docs/source/core/neural_types.rst @@ -24,6 +24,7 @@ Types are implemented in ``nemo.core.neural_types.NeuralType`` class. When you i are expected to include both *axes* information and *element type* information. .. autoclass:: nemo.core.neural_types.NeuralType + :no-index: Type Comparison Results ----------------------- @@ -31,6 +32,7 @@ Type Comparison Results When comparing two neural types, the following comparison results are generated. .. autoclass:: nemo.core.neural_types.NeuralTypeComparisonResult + :no-index: Examples -------- @@ -113,6 +115,7 @@ Custom element types It is possible to create user-defined element types to express the semantics of elements in your tensors. To do so, the user will need to inherit and implement abstract methods of the ``nemo.core.neural_types.elements.ElementType`` class .. autoclass:: nemo.core.neural_types.elements.ElementType + :no-index: Note that element types can be parametrized. Consider this example where it distinguishes between audio sampled at 8Khz and 16Khz. diff --git a/docs/source/features/memory_optimizations.rst b/docs/source/features/memory_optimizations.rst index 0e0b3ad84402..d72d54ab7c2c 100644 --- a/docs/source/features/memory_optimizations.rst +++ b/docs/source/features/memory_optimizations.rst @@ -3,7 +3,7 @@ Memory Optimizations Parallelism ----------- -Refer to :doc:`Parallelism <./parallelism>`. +Refer to :doc:`Parallelism <./parallelisms>`. Flash Attention --------------- @@ -20,10 +20,8 @@ In the NeMo Framework, Flash Attention is supported through the Transformer Engi For more details on the supported Dot Attention backend, please refer to the Transformer Engine source code available at `Transformer Engine's Attention Mechanism `_. -.. bibliography:: ./nlp_all.bib - :style: plain - :labelprefix: nlp-megatron - :keyprefix: nlp-megatron- +Activation Recomputation +------------------------ Overview ^^^^^^^^ @@ -41,8 +39,3 @@ Selective Activation Recomputation This method reduces memory footprint of activations significantly via smart activation checkpointing. This approach involves selectively storing only crucial activations and recomputing the others as needed. It is particularly useful in large models to minimize memory usage while controlling the computational cost. Refer to "Reducing Activation Recomputation in Large Transformer Models" for more details: https://arxiv.org/abs/2205.05198 - -.. bibliography:: ./nlp_all.bib - :style: plain - :labelprefix: nlp-megatron - :keyprefix: nlp-megatron- \ No newline at end of file diff --git a/docs/source/multimodal/api.rst b/docs/source/multimodal/api.rst index d6f96e6c6ea4..3228cd76d4ad 100644 --- a/docs/source/multimodal/api.rst +++ b/docs/source/multimodal/api.rst @@ -8,6 +8,7 @@ Model Classes :show-inheritance: :no-members: :members: __init__, configure_optimizers + :no-index: .. autoclass:: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm.MegatronLatentDiffusion @@ -16,18 +17,18 @@ Model Classes :members: __init__, training_step, validation_step, setup, build_train_valid_test_datasets -.. autoclass:: nemo.collections.multimodal.models.dreambooth.dreambooth.MegatronDreamBooth +.. autoclass:: nemo.collections.multimodal.models.text_to_image.dreambooth.dreambooth.MegatronDreamBooth :show-inheritance: :no-members: :members: __init__, training_step, validation_step, setup, build_train_valid_test_datasets -.. autoclass:: nemo.collections.multimodal.models.controlnet.controlnet.MegatronControlNet +.. autoclass:: nemo.collections.multimodal.models.text_to_image.controlnet.controlnet.MegatronControlNet :show-inheritance: :no-members: :members: __init__, training_step, validation_step, setup, build_train_valid_test_datasets -.. autoclass:: nemo.collections.multimodal.models.imagen.imagen.MegatronImagen +.. autoclass:: nemo.collections.multimodal.models.text_to_image.imagen.imagen.MegatronImagen :show-inheritance: :no-members: :members: __init__, training_step, validation_step, setup, build_train_valid_test_datasets @@ -65,7 +66,7 @@ Modules :members: __init__, encode -.. autoclass:: nemo.collections.multimodal.models.controlnet.controlnet.ControlledUnetModel +.. autoclass:: nemo.collections.multimodal.models.text_to_image.controlnet.controlnet.ControlledUnetModel :show-inheritance: :no-members: :members: forward diff --git a/docs/source/multimodal/mllm/checkpoint.rst b/docs/source/multimodal/mllm/checkpoint.rst index 46c6da631ba2..d1fe7b651e66 100644 --- a/docs/source/multimodal/mllm/checkpoint.rst +++ b/docs/source/multimodal/mllm/checkpoint.rst @@ -41,7 +41,7 @@ Converting Local Checkpoints The training script only auto-converts the final checkpoint into the ``.nemo`` format. To evaluate intermediate training checkpoints, conversion to ``.nemo`` might be needed. For this: -.. code-block:: python +.. code-block:: bash python -m torch.distributed.launch --nproc_per_node= * \ examples/multimodal/convert_ckpt_to_nemo.py \ @@ -59,12 +59,12 @@ NeVA Checkpoints Currently, the conversion mainly supports LLaVA checkpoints based on "llama-2 chat" checkpoints. As a reference, we'll consider the checkpoint `llava-llama-2-13b-chat-lightning-preview `_. -After downloading this checkpoint and saving it at `/path/to/llava-llama-2-13b-chat-lightning-preview`, undertake the following procedures: +After downloading this checkpoint and saving it at ``/path/to/llava-llama-2-13b-chat-lightning-preview``, undertake the following procedures: Modifying the Tokenizer """"""""""""""""""""""" -NeMo mandates adding specific tokens to the tokenizer model for peak performance. To modify an existing tokenizer located in `/path/to/llava-llama-2-13b-chat-lightning-preview/tokenizer`, execute the following in the NeMo container: +NeMo mandates adding specific tokens to the tokenizer model for peak performance. To modify an existing tokenizer located in ``/path/to/llava-llama-2-13b-chat-lightning-preview/tokenizer``, execute the following in the NeMo container: .. code-block:: bash @@ -82,7 +82,7 @@ Checkpoint Conversion For conversion: -.. code-block:: python +.. code-block:: bash python examples/multimodal/mllm/neva/convert_hf_llava_to_neva.py \ --in-file /path/to/llava-llama-2-13b-chat-lightning-preview \ @@ -99,7 +99,7 @@ NeVA Checkpoints Adjust model parallelism with: -.. code-block:: python +.. code-block:: bash python examples/nlp/language_modeling/megatron_change_num_partitions.py \ --model_file=/path/to/source.nemo \ diff --git a/docs/source/multimodal/nerf/dreamfusion.rst b/docs/source/multimodal/nerf/dreamfusion.rst index a9f2f630bcdd..d6c926392556 100644 --- a/docs/source/multimodal/nerf/dreamfusion.rst +++ b/docs/source/multimodal/nerf/dreamfusion.rst @@ -3,7 +3,7 @@ DreamFusion Model Introduction ------------------- -DreamFusion :cite:`mm-models-poole2022dreamfusion` uses a pretrained text-to-image diffusion model to perform +DreamFusion :cite:`mm-models-df-poole2022dreamfusion` uses a pretrained text-to-image diffusion model to perform text-to-3D synthesis. The model uses a loss based on probability density distillation that enables the use of a 2D diffusion model as a prior for optimization of a parametric image generator. @@ -306,5 +306,5 @@ References .. bibliography:: ../mm_all.bib :style: plain :filter: docname in docnames - :labelprefix: MM-MODELS - :keyprefix: mm-models- + :labelprefix: MM-MODELS-DF + :keyprefix: mm-models-df- diff --git a/docs/source/multimodal/text2img/controlnet.rst b/docs/source/multimodal/text2img/controlnet.rst index 6eae36dd017a..b9f55031b79d 100644 --- a/docs/source/multimodal/text2img/controlnet.rst +++ b/docs/source/multimodal/text2img/controlnet.rst @@ -4,12 +4,12 @@ ControlNet Model Introduction -------------------- -ControlNet :cite:`mm-models-controlnetgithub` is a neural network structure to control diffusion models by adding extra conditions. +ControlNet :cite:`mm-models-cn-controlnetgithub` is a neural network structure to control diffusion models by adding extra conditions. It copies the weights of neural network blocks into a "locked" copy and a "trainable" copy. The "trainable" one learns your condition. The "locked" one preserves your model. In this way, the ControlNet can reuse the SD encoder as a deep, strong, robust, and powerful backbone to learn diverse controls. NeMo Multimodal provides a training pipeline and example implementation for generating images based on segmentation maps. Users have the flexibility to explore other implementations using their own control input dataset and recipe. .. image:: ./images/controlnet-structure.png - :alt: ControlNet structure on stable diffusion (See :cite:`mm-models-controlnetgithub`) + :alt: ControlNet structure on stable diffusion (See :cite:`mm-models-cn-controlnetgithub`) ControlNet Dataset @@ -102,5 +102,5 @@ Reference .. bibliography:: ../mm_all.bib :style: plain :filter: docname in docnames - :labelprefix: MM-MODELS - :keyprefix: mm-models- + :labelprefix: MM-MODELS-CN + :keyprefix: mm-models-cn- diff --git a/docs/source/multimodal/text2img/dreambooth.rst b/docs/source/multimodal/text2img/dreambooth.rst index fa7e52a7ccbb..1c6a420d49f2 100644 --- a/docs/source/multimodal/text2img/dreambooth.rst +++ b/docs/source/multimodal/text2img/dreambooth.rst @@ -5,7 +5,7 @@ DreamBooth Model Introduction -------------------- -DreamBooth :cite:`mm-models-dreamboothpaper` is a fine-tuning technique and a solution to personalize large diffusion models like Stable Diffusion, which are powerful but lack the +DreamBooth :cite:`mm-models-db-dreamboothpaper` is a fine-tuning technique and a solution to personalize large diffusion models like Stable Diffusion, which are powerful but lack the ability to mimic subjects of a given reference set. With DreamBooth, you only need a few images of a specific subject to fine-tune a pretrained text-to-image model, so that it learns to bind a unique identifier with a special subject. This unique identifier can then be used to synthesize fully-novel photorealistic images of the subject contextualized in @@ -28,7 +28,7 @@ NeMo's Dreambooth is built upon the Stable Diffusion framework. While its archit - Training Dataset - NeMo's Dreambooth model dataset is different from other NeMo multimodal models in that it doesn't necessitate data stored in the webdataset format. You can find a sample dataset at :cite:`mm-models-dreamboothdataset`. For each object you aim to integrate into the model, just place its images (typically 3-5) in a folder and specify its path in ``model.data.instance_dir``. When training with the prior preservation loss, store images produced by the original model in a distinct folder and reference its path in ``model.data.regularization_dir``. This process is automated in NeMo's DreamBooth implementation. + NeMo's Dreambooth model dataset is different from other NeMo multimodal models in that it doesn't necessitate data stored in the webdataset format. You can find a sample dataset at :cite:`mm-models-db-dreamboothdataset`. For each object you aim to integrate into the model, just place its images (typically 3-5) in a folder and specify its path in ``model.data.instance_dir``. When training with the prior preservation loss, store images produced by the original model in a distinct folder and reference its path in ``model.data.regularization_dir``. This process is automated in NeMo's DreamBooth implementation. Model Configuration -------------------- @@ -130,5 +130,5 @@ Reference .. bibliography:: ../mm_all.bib :style: plain :filter: docname in docnames - :labelprefix: MM-MODELS - :keyprefix: mm-models- + :labelprefix: MM-MODELS-DB + :keyprefix: mm-models-db- diff --git a/docs/source/multimodal/text2img/imagen.rst b/docs/source/multimodal/text2img/imagen.rst index 9aeff2f2a061..844f68df747f 100644 --- a/docs/source/multimodal/text2img/imagen.rst +++ b/docs/source/multimodal/text2img/imagen.rst @@ -4,7 +4,7 @@ Imagen Model Introduction ------------------- -Imagen :cite:`mm-models-saharia2022photorealistic` is a multi-stage text-to-image diffusion model with an unprecedented +Imagen :cite:`mm-models-imagen-saharia2022photorealistic` is a multi-stage text-to-image diffusion model with an unprecedented degree of photorealism and a deep level of language understanding. Given a text prompt, Imagen first generates an image at a 64x64 resolution and then upsamples the generated image to 256x256 and 1024x1024 resolutions, all using diffusion models. @@ -75,9 +75,9 @@ Recommended Efficient UNet size for SR256 and SR1024 models are listed below: Noise Scheduling / Sampler ^^^^^^^^^^^^^^^^^^^^^^^^^^ -NeMo Imagen supports two types of noise scheduling: Continous DDPM :cite:`mm-models-nichol2021improved` and EDM :cite:`mm-models-karras2022elucidating`. +NeMo Imagen supports two types of noise scheduling: Continous DDPM :cite:`mm-models-imagen-nichol2021improved` and EDM :cite:`mm-models-imagen-karras2022elucidating`. -Denoising diffusion probabilistic models (DDPM) :cite:`mm-models-ho2020denoising` +Denoising diffusion probabilistic models (DDPM) :cite:`mm-models-imagen-ho2020denoising` represents the most widely adopted noise scheduling approach among all diffusion models. Continuous DDPM introduces several modifications to the standard DDPM framework, with the most significant change being the transition from a discrete noise space to a continuous space. @@ -285,5 +285,5 @@ Reference .. bibliography:: ../mm_all.bib :style: plain :filter: docname in docnames - :labelprefix: MM-MODELS - :keyprefix: mm-models- + :labelprefix: MM-MODELS-IMAGEN + :keyprefix: mm-models-imagen- diff --git a/docs/source/multimodal/text2img/insp2p.rst b/docs/source/multimodal/text2img/insp2p.rst index 177734584bc7..282874444738 100644 --- a/docs/source/multimodal/text2img/insp2p.rst +++ b/docs/source/multimodal/text2img/insp2p.rst @@ -4,7 +4,7 @@ InstructPix2Pix Model Introduction -------------------- -InstructPix2Pix [InstructPix2Pix]_ :cite:`mm-models-insp2p` offers a unique approach to image editing using human-written instructions. Given an input image and a textual directive, the model adjusts the image according to the provided instructions. NeMo Multimodal presents a training pipeline for this conditional diffusion model, utilizing a dataset generated by harnessing the strengths of two prominent pretrained models: a language model (GPT-3) and a text-to-image model (Stable Diffusion). The InstructPix2Pix model operates swiftly, editing images within seconds, eliminating the need for per-example fine-tuning or inversion. It has demonstrated remarkable results across a wide variety of input images and written instructions. +InstructPix2Pix [InstructPix2Pix]_ :cite:`mm-models-insp2p-insp2p` offers a unique approach to image editing using human-written instructions. Given an input image and a textual directive, the model adjusts the image according to the provided instructions. NeMo Multimodal presents a training pipeline for this conditional diffusion model, utilizing a dataset generated by harnessing the strengths of two prominent pretrained models: a language model (GPT-3) and a text-to-image model (Stable Diffusion). The InstructPix2Pix model operates swiftly, editing images within seconds, eliminating the need for per-example fine-tuning or inversion. It has demonstrated remarkable results across a wide variety of input images and written instructions. Built upon the Stable Diffusion framework, NeMo's InstructPix2Pix shares a similar architecture with Stable Diffusion (refer to :doc:`Stable Diffusion <./sd>`). What sets it apart is its unique training dataset and the combined guidance from both image and text prompts. Specifically, InstructPix2pix ::class::``nemo.collections.multimodal.models.instruct_pix2pix.ldm.ddpm_edit.MegatronLatentDiffusionEdit`` is derived directly from Stable Diffusion's ::class::``nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm.MegatronLatentDiffusion``, with alterations to accommodate the dataset and provide support for dual guidance. @@ -79,7 +79,7 @@ References .. bibliography:: ../mm_all.bib :style: plain :filter: docname in docnames - :labelprefix: MM-MODELS - :keyprefix: mm-models- + :labelprefix: MM-MODELS-INSP2P + :keyprefix: mm-models-insp2p- diff --git a/docs/source/multimodal/text2img/intro.rst b/docs/source/multimodal/text2img/intro.rst index 3c3c17768679..599c9bae5e15 100644 --- a/docs/source/multimodal/text2img/intro.rst +++ b/docs/source/multimodal/text2img/intro.rst @@ -13,4 +13,5 @@ NeMo multimodal provides implementations of multiple image-to-text models, inclu imagen dreambooth controlnet + insp2p sdxl_quantization diff --git a/docs/source/multimodal/text2img/sdxl_quantization.rst b/docs/source/multimodal/text2img/sdxl_quantization.rst index 78403e9c402c..68bb7ff8d511 100644 --- a/docs/source/multimodal/text2img/sdxl_quantization.rst +++ b/docs/source/multimodal/text2img/sdxl_quantization.rst @@ -7,16 +7,17 @@ This example shows how to use Ammo to calibrate and quantize the UNet part of th We also provide instructions on deploying and running E2E SDXL pipeline with Ammo quantized int8 UNet to generate images and measure latency on target GPUs. -To get started, it is required to have a pretrained SDXL checkpoint in `nemo` format. The example training configs are provided in NeMo, -which is located in `NeMo/examples/multimodal/text2img/stable_diffusion`. +To get started, it is required to have a pretrained SDXL checkpoint in ``nemo`` format. The example training configs are provided in NeMo, +which is located in ``NeMo/examples/multimodal/text2img/stable_diffusion``. Calibration --------------- The first step is to run quantization script with default config, and finally the script will export the quantized unet to onnx file. -Here is the default config for `NeMo/examples/multimodal/text2img/stable_diffusion/sd_xl_quantize.py`. +Here is the default config for ``NeMo/examples/multimodal/text2img/stable_diffusion/sd_xl_quantize.py``. .. code-block:: yaml + quantize exp_name: nemo n_steps: 20 # number of inference steps @@ -41,6 +42,7 @@ Build the TRT engine for the Quantized ONNX UNet ------------------------------------------------------------ .. code-block:: bash + trtexec --onnx=./nemo_onnx/unet.onnx --shapes=x:8x4x128x128,timesteps:8,context:8x80x2048,y:8x2816 --fp16 --int8 --builderOptimizationLevel=4 --saveEngine=nemo_unet_xl.plan @@ -57,6 +59,7 @@ Build End-to-end Stable Diffusion XL Pipeline with NeMo We provide a script to build end to end TRT inference pipeline with NeMo backend, which is located at `NeMo/examples/multimodal/text2img/stable_diffusion/sd_xl_export.py`. .. code-block:: yaml + infer: out_path: sdxl_export width: 1024 @@ -82,6 +85,7 @@ Run End-to-end Stable Diffusion XL TRT Pipeline The inference script can be found at `NeMo/examples/multimodal/text2img/stable_diffusion/sd_xl_trt_inference.py`. .. code-block:: yaml + unet_xl: sdxl_export/plan/unet_xl.plan vae: sdxl_export/plan/vae.plan clip1: sdxl_export/plan/clip1.plan diff --git a/docs/source/multimodal/vlm/clip.rst b/docs/source/multimodal/vlm/clip.rst index e28fb836ff4a..976baadb5a83 100644 --- a/docs/source/multimodal/vlm/clip.rst +++ b/docs/source/multimodal/vlm/clip.rst @@ -4,7 +4,7 @@ CLIP Model Introduction ------------------- -Contrastive Language-Image Pre-training (CLIP) :cite:`mm-models-radford2021learning` offers an efficient method for learning image representations using natural language supervision. The essence of CLIP is to train both an image encoder and a text encoder from scratch. The model aims to predict the correct pairings of a batch of (image, text) training examples by jointly training these encoders. During pre-training, CLIP is designed to predict which images and texts form a semantically coherent pair by maximizing the similarity between the correct (image, text) pairs while minimizing the similarity between incorrect pairs. This contrastive learning approach ensures that CLIP learns meaningful and contextually rich representations of both visual and textual data. +Contrastive Language-Image Pre-training (CLIP) :cite:`mm-models-clip-radford2021learning` offers an efficient method for learning image representations using natural language supervision. The essence of CLIP is to train both an image encoder and a text encoder from scratch. The model aims to predict the correct pairings of a batch of (image, text) training examples by jointly training these encoders. During pre-training, CLIP is designed to predict which images and texts form a semantically coherent pair by maximizing the similarity between the correct (image, text) pairs while minimizing the similarity between incorrect pairs. This contrastive learning approach ensures that CLIP learns meaningful and contextually rich representations of both visual and textual data. NeMo's implementation of the CLIP model leverages its parallel transformer implementation, specifically the `nemo.collections.nlp.modules.common.megatron.transformer.ParallelTransformer`, to enable model parallelism support in both the text encoder and vision model. This design choice ensures efficient scaling and utilization of resources during training. Additionally, some of the model design and loss implementations in NeMo's CLIP are inspired by the open-source [open_clip](https://github.com/mlfoundations/open_clip) repository. @@ -153,5 +153,5 @@ References .. bibliography:: ../mm_all.bib :style: plain :filter: docname in docnames - :labelprefix: MM-MODELS - :keyprefix: mm-models- + :labelprefix: MM-MODELS-CLIP + :keyprefix: mm-models-clip- diff --git a/docs/source/nlp/api.rst b/docs/source/nlp/api.rst index b9b4d529ba46..52c1b537b0bf 100755 --- a/docs/source/nlp/api.rst +++ b/docs/source/nlp/api.rst @@ -22,7 +22,7 @@ Pretraining Model Classes .. autoclass:: nemo.collections.nlp.models.language_modeling.megatron_bart_model.MegatronBARTModel :show-inheritance: :no-members: - :members: training_step, validation_step, build_train_valid_test_datasets, setup, on_save_checkpoint, on_load_checkpoint + :members: training_step, validation_step, build_train_valid_test_datasets, setup .. autoclass:: nemo.collections.nlp.models.language_modeling.megatron_retrieval_model.MegatronRetrievalModel :show-inheritance: @@ -45,32 +45,27 @@ Customization Model Classes .. autoclass:: nemo.collections.nlp.models.language_modeling.megatron_gpt_adapter_model.MegatronGPTAdapterLearningModel :show-inheritance: :no-members: - :members: __init__, state_dict, generate, training_step, validation_step, build_train_valid_test_datasets, setup + :members: __init__, state_dict, generate, training_step, validation_step, setup .. autoclass:: nemo.collections.nlp.models.language_modeling.megatron_gpt_adapter_model.MegatronGPTInfusedAdapterModel :show-inheritance: :no-members: - :members: __init__, state_dict, generate, training_step, validation_step, build_train_valid_test_datasets, setup + :members: __init__, state_dict, generate, training_step, validation_step, setup .. autoclass:: nemo.collections.nlp.models.language_modeling.megatron_gpt_prompt_learning_model.MegatronGPTPromptLearningModel :show-inheritance: :no-members: - :members: built_virtual_prompt_dataset, generate, training_step, validation_step, build_train_valid_test_datasets, setup + :members: build_virtual_prompt_dataset, generate, training_step, validation_step, setup .. autoclass:: nemo.collections.nlp.models.language_modeling.megatron_t5_adapter_model.MegatronT5AdapterLearningModel :show-inheritance: :no-members: - :members: __init__, state_dict, training_step, validation_step, build_train_valid_test_datasets, setup - -.. autoclass:: nemo.collections.nlp.models.language_modeling.megatron_t5_adapter_model.MegatronT5AdapterLearningModel - :show-inheritance: - :no-members: - :members: _add_adapters_to_component, __init__, state_dict, training_step, validation_step, build_train_valid_test_datasets, setup + :members: _add_adapters_to_component, __init__, state_dict, training_step, validation_step, setup .. autoclass:: nemo.collections.nlp.models.language_modeling.megatron_t5_adapter_model.MegatronT5InfusedAdapterModel :show-inheritance: :no-members: - :members: _add_adapters_to_component, __init__, state_dict, training_step, validation_step, build_train_valid_test_datasets, setup + :members: _add_adapters_to_component, __init__, state_dict, training_step, validation_step, setup Modules ------- @@ -86,7 +81,7 @@ Modules :no-members: :members: forward -.. autoclass:: nemo.collections.nlp.models.language_modeling.megatron.bert_model.BertModel +.. autoclass:: nemo.collections.nlp.models.language_modeling.megatron.bert.bert_model.NeMoBertModel :show-inheritance: :no-members: :members: forward diff --git a/docs/source/nlp/information_retrieval.rst b/docs/source/nlp/information_retrieval.rst index fa9157e45b59..26732283e8f4 100644 --- a/docs/source/nlp/information_retrieval.rst +++ b/docs/source/nlp/information_retrieval.rst @@ -53,7 +53,7 @@ BERT checkpoint to NeMo (mcore) using the following: Then you can fine-tune the sentence-BERT model using the following script: -.. code-block:: python +.. code-block:: bash #!/bin/bash diff --git a/docs/source/nlp/machine_translation/machine_translation.rst b/docs/source/nlp/machine_translation/machine_translation.rst index 190ac5b07da9..f58c67551abe 100644 --- a/docs/source/nlp/machine_translation/machine_translation.rst +++ b/docs/source/nlp/machine_translation/machine_translation.rst @@ -470,12 +470,12 @@ NMT with bottleneck encoder architecture is also supported (i.e., fixed size bot 1. Supported learning frameworks (**model.model_type**): * NLL - Conditional cross entropy (the usual NMT loss) - * VAE - Variational Auto-Encoder (`paper `_) - * MIM - Mutual Information Machine (`paper `_) + * VAE - Variational Auto-Encoder (`paper `__) + * MIM - Mutual Information Machine (`paper `__) 2. Supported encoder architectures (**model.encoder.arch**): * seq2seq - the usual transformer encoder without a bottleneck - * bridge - attention bridge bottleneck (`paper `_) - * perceiver - Perceiver bottleneck (`paper `_) + * bridge - attention bridge bottleneck (`paper `__) + * perceiver - Perceiver bottleneck (`paper `__) +----------------------------------------+----------------+--------------+-------------------------------------------------------------------------------------------------------+ diff --git a/docs/source/nlp/nemo_megatron/gpt/gpt_training.rst b/docs/source/nlp/nemo_megatron/gpt/gpt_training.rst index 2e94cc45b40f..efc2ac3f8439 100644 --- a/docs/source/nlp/nemo_megatron/gpt/gpt_training.rst +++ b/docs/source/nlp/nemo_megatron/gpt/gpt_training.rst @@ -70,7 +70,7 @@ Note that training tokenizer model will also take some time. --pad_id=0 --unk_id=1 --bos_id=2 --eos_id=3 \ --split_digits true -After this is done (will take a while), you'll have two files: ```spm_32k_wiki.model``` and ```spm_32k_wiki.vocab``corresponding to the model and vocabulary. +After this is done (will take a while), you'll have two files: ``spm_32k_wiki.model`` and ``spm_32k_wiki.vocab`` corresponding to the model and vocabulary. **Step 4: Convert training data into memory map format** diff --git a/docs/source/nlp/nemo_megatron/positional_embeddings.rst b/docs/source/nlp/nemo_megatron/positional_embeddings.rst index 332ce304049d..cac0bb452f58 100644 --- a/docs/source/nlp/nemo_megatron/positional_embeddings.rst +++ b/docs/source/nlp/nemo_megatron/positional_embeddings.rst @@ -18,38 +18,38 @@ GPT - .. code:: model.position_embedding_type='learned_absolute' - - Absolute Position Encodings :cite:`nlp-megatron-vaswani2023attention` are position embeddings used in Transformer-based models, added to input embeddings in the encoder and decoder sections. These encodings match the dimension of embeddings and are created using sine and cosine functions of various frequencies. Each dimension in the encoding corresponds to a sinusoid with wavelengths forming a geometric progression. + - Absolute Position Encodings :cite:`pos-emb-vaswani2023attention` are position embeddings used in Transformer-based models, added to input embeddings in the encoder and decoder sections. These encodings match the dimension of embeddings and are created using sine and cosine functions of various frequencies. Each dimension in the encoding corresponds to a sinusoid with wavelengths forming a geometric progression. * - **rope** - .. code:: model.position_embedding_type='rope' model.rotary_percentage=1.0 - - Rotary Position Embedding (RoPE) :cite:`nlp-megatron-su2022roformer` incorporates positional information by utilizing a rotation matrix to encode the absolute positions of tokens while maintaining relative positional relationships in self-attention formulations. It achieves this by leveraging the geometric properties of vectors and complex numbers and applying a rotation based on a preset non-zero constant and the relative positions of the tokens to the word embeddings. - + - Rotary Position Embedding (RoPE) :cite:`pos-emb-su2022roformer` incorporates positional information by utilizing a rotation matrix to encode the absolute positions of tokens while maintaining relative positional relationships in self-attention formulations by leveraging the geometric properties of vectors and complex numbers, applying a rotation based on a preset non-zero constant and the relative positions of the tokens to the word embeddings. + * - **alibi** - .. code:: model.position_embedding_type='alibi' - - Attention with Linear Biases (ALiBi) :cite:`nlp-megatron-press2022train` modifies the way attention scores are computed in the attention sublayer of the network. ALiBi introduces a static, non-learned bias after the query-key dot product during the computation of attention scores. This bias is added in the form of a head-specific slope that is determined before training, creating a geometric sequence of slopes for the different heads in the model. The method has an inductive bias towards recency, penalizing attention scores between distant query-key pairs with the penalty increasing as the distance grows, and it leverages different rates of penalty increase across different heads based on the slope magnitude. + - Attention with Linear Biases (ALiBi) :cite:`pos-emb-press2022train` modifies the way attention scores are computed in the attention sublayer of the network. ALiBi introduces a static, non-learned bias after the query-key dot product during the computation of attention scores. This bias is added in the form of a head-specific slope that is determined before training, creating a geometric sequence of slopes for the different heads in the model. The method has an inductive bias towards recency, penalizing attention scores between distant query-key pairs with the penalty increasing as the distance grows, and it leverages different rates of penalty increase across different heads based on the slope magnitude. * - **kerple** - .. code:: model.position_embedding_type='kerple' - - Kernelized Relative Positional Embedding for Length Extrapolation (KERPLE) :cite:`nlp-megatron-chi2022kerple` generalizes relative positional embeddings (RPE) by kernelizing positional differences using Conditionally Positive Definite (CPD) kernels known for generalizing distance metrics. They transform CPD kernels into positive definite (PD) kernels by adding a constant offset, which is absorbed during softmax normalization in the self-attention mechanism of transformers. This approach allows for a variety of RPEs that facilitate length extrapolation in a principled manner. + - Kernelized Relative Positional Embedding for Length Extrapolation (KERPLE) :cite:`pos-emb-chi2022kerple` generalizes relative positional embeddings (RPE) by kernelizing positional differences using conditionally positive definite (CPD) kernels known for generalizing distance metrics. They transform CPD kernels into positive definite (PD) kernels by adding a constant offset, which is absorbed during softmax normalization in the self-attention mechanism of transformers. This approach allows for a variety of RPEs that facilitate length extrapolation in a principled manner. * - **xpos** - .. code:: model.position_embedding_type='xpos' - - Extrapolatable Position Embedding (xPos) :cite:`nlp-megatron-sun2022lengthextrapolatable` + - Extrapolatable Position Embedding (xPos) :cite:`pos-emb-sun2022lengthextrapolatable` * - **sandwich** - .. code:: model.position_embedding_type='sandwich' - - Sandwich :cite:`nlp-megatron-chi2023dissecting` + - Sandwich :cite:`pos-emb-chi2023dissecting` T5 ^^ @@ -67,32 +67,32 @@ T5 model.encoder.position_embedding_type='learned_absolute' model.decoder.position_embedding_type='learned_absolute' - - Absolute Position Encodings :cite:`nlp-megatron-vaswani2023attention` are position embeddings used in Transformer-based models, added to input embeddings in the encoder and decoder sections. These encodings match the dimension of embeddings and are created using sine and cosine functions of various frequencies. Each dimension in the encoding corresponds to a sinusoid with wavelengths forming a geometric progression. + - Absolute Position Encodings :cite:`pos-emb-vaswani2023attention` are position embeddings used in Transformer-based models, added to input embeddings in the encoder and decoder sections. These encodings match the dimension of embeddings and are created using sine and cosine functions of various frequencies. Each dimension in the encoding corresponds to a sinusoid with wavelengths forming a geometric progression. * - **relative** - .. code:: model.encoder.position_embedding_type='relative' model.decoder.position_embedding_type='relative' - - Relative Position Representations :cite:`nlp-megatron-shaw2018selfattention` + - Relative Position Representations :cite:`pos-emb-shaw2018selfattention` * - **alibi** - .. code:: model.encoder.position_embedding_type='alibi' model.decoder.position_embedding_type='alibi' - - Attention with Linear Biases (ALiBi) :cite:`nlp-megatron-press2022train` modifies the way attention scores are computed in the attention sublayer of the network. ALiBi introduces a static, non-learned bias after the query-key dot product during the computation of attention scores. This bias is added in the form of a head-specific slope that is determined before training, creating a geometric sequence of slopes for the different heads in the model. The method has an inductive bias towards recency, penalizing attention scores between distant query-key pairs with the penalty increasing as the distance grows, and it leverages different rates of penalty increase across different heads based on the slope magnitude. + - Attention with Linear Biases (ALiBi) :cite:`pos-emb-press2022train` modifies the way attention scores are computed in the attention sublayer of the network. ALiBi introduces a static, non-learned bias after the query-key dot product during the computation of attention scores. This bias is added in the form of a head-specific slope that is determined before training, creating a geometric sequence of slopes for the different heads in the model. The method has an inductive bias towards recency, penalizing attention scores between distant query-key pairs with the penalty increasing as the distance grows, and it leverages different rates of penalty increase across different heads based on the slope magnitude. * - **kerple** - .. code:: model.encoder.position_embedding_type='kerple' model.decoder.position_embedding_type='kerple' - - Kernelized Relative Positional Embedding for Length Extrapolation (KERPLE) :cite:`nlp-megatron-chi2022kerple` generalizes relative positional embeddings (RPE) by kernelizing positional differences using Conditionally Positive Definite (CPD) kernels known for generalizing distance metrics. They transform CPD kernels into positive definite (PD) kernels by adding a constant offset, which is absorbed during softmax normalization in the self-attention mechanism of transformers. This approach allows for a variety of RPEs that facilitate length extrapolation in a principled manner. + - Kernelized Relative Positional Embedding for Length Extrapolation (KERPLE) :cite:`pos-emb-chi2022kerple` generalizes relative positional embeddings (RPE) by kernelizing positional differences using conditionally positive definite (CPD) kernels known for generalizing distance metrics. They transform CPD kernels into positive definite (PD) kernels by adding a constant offset, which is absorbed during softmax normalization in the self-attention mechanism of transformers. This approach allows for a variety of RPEs that facilitate length extrapolation in a principled manner. Positional interpolation ------------------------ -Position Interpolation (PI) :cite:`nlp-megatron-chen2023extending` is a method introduced to extend the context window sizes of Rotary Position Embedding (RoPE)-based pretrained large language models (LLMs). The central principle of PI is to reduce the position indices so that they align with the initial context window size through interpolation. +Position Interpolation (PI) :cite:`pos-emb-chen2023extending` is a method introduced to extend the context window sizes of Rotary Position Embedding (RoPE)-based pretrained large language models (LLMs). The central principle of PI is to reduce the position indices so that they align with the initial context window size through interpolation. Positional Interpolation is supported in Megatron GPT SFT models. Set RoPE Interpolation factor for sequence length :code:`seq_len_interpolation_factor` to enable it. @@ -107,5 +107,5 @@ References .. bibliography:: ../nlp_all.bib :style: plain - :labelprefix: nlp-megatron - :keyprefix: nlp-megatron- \ No newline at end of file + :labelprefix: pos-emb + :keyprefix: pos-emb- \ No newline at end of file diff --git a/docs/source/nlp/punctuation_and_capitalization_lexical_audio.rst b/docs/source/nlp/punctuation_and_capitalization_lexical_audio.rst index 8314676e5c4c..4cd13abd2264 100644 --- a/docs/source/nlp/punctuation_and_capitalization_lexical_audio.rst +++ b/docs/source/nlp/punctuation_and_capitalization_lexical_audio.rst @@ -36,7 +36,7 @@ Quick Start Guide Model Description ----------------- In addition to :doc:`Punctuation And Capitalization model <./punctuation_and_capitalization>` we add audio encoder (e.g. Conformer's encoder) and attention based fusion of lexical and audio features. -This model architecture is based on `Multimodal Semi-supervised Learning Framework for Punctuation Prediction in Conversational Speech `__ :cite:`nlp-punct-sunkara20_interspeech`. +This model architecture is based on `Multimodal Semi-supervised Learning Framework for Punctuation Prediction in Conversational Speech `__ :cite:`nlp-punct-lex-sunkara20_interspeech`. .. note:: @@ -386,6 +386,6 @@ References .. bibliography:: nlp_all.bib :style: plain - :labelprefix: NLP-PUNCT - :keyprefix: nlp-punct- + :labelprefix: NLP-PUNCT-LEX + :keyprefix: nlp-punct-lex- diff --git a/docs/source/nlp/text_normalization/text_normalization_as_tagging.rst b/docs/source/nlp/text_normalization/text_normalization_as_tagging.rst index 672226622357..702fb9425026 100644 --- a/docs/source/nlp/text_normalization/text_normalization_as_tagging.rst +++ b/docs/source/nlp/text_normalization/text_normalization_as_tagging.rst @@ -59,7 +59,7 @@ In the example, ```` denotes that the spoken form is the same as the writt -More information about the Google Text Normalization Dataset can be found in the paper `RNN Approaches to Text Normalization: A Challenge `__ :cite:`nlp-textnorm-sproat2016rnn`. +More information about the Google Text Normalization Dataset can be found in the paper `RNN Approaches to Text Normalization: A Challenge `__ :cite:`nlp-textnorm-tag-sproat2016rnn`. Data preprocessing @@ -146,7 +146,7 @@ contextualized representation for each input token. It then uses a classificatio to predict the tag for each token. Another classification head is used to predict a "semiotic" class label for each token. Overall, our design is partly inspired by the LaserTagger approach proposed in the paper -`Encode, tag, realize: High-precision text editing `__ :cite:`nlp-textnorm-malmi2019encode`. +`Encode, tag, realize: High-precision text editing `__ :cite:`nlp-textnorm-tag-malmi2019encode`. The LaserTagger method is not directly applicable to ITN because it can only regard the whole non-common fragment as a single replacement tag, whereas spoken-to-written conversion, e.g. a date, needs to be aligned on a more granular level. Otherwise, @@ -161,5 +161,5 @@ References .. bibliography:: tn_itn_all.bib :style: plain - :labelprefix: NLP-TEXTNORM - :keyprefix: nlp-textnorm- + :labelprefix: NLP-TEXTNORM-TAG + :keyprefix: nlp-textnorm-tag diff --git a/docs/source/starthere/best-practices.rst b/docs/source/starthere/best-practices.rst index ec0fea1985cc..759ee108ed7b 100644 --- a/docs/source/starthere/best-practices.rst +++ b/docs/source/starthere/best-practices.rst @@ -23,7 +23,7 @@ NeMo excels in training large-scale LLM & MM, utilizing optimizations from Megat - Advanced checkpointing through the Distributed Checkpoint Format. Speech AI --------- +--------- Data Augmentation ~~~~~~~~~~~~~~~~~ diff --git a/docs/source/starthere/migration-guide.rst b/docs/source/starthere/migration-guide.rst index 1d9816493a5b..7005873e5343 100644 --- a/docs/source/starthere/migration-guide.rst +++ b/docs/source/starthere/migration-guide.rst @@ -8,39 +8,39 @@ Upgrade guide to use lightning 2.0 .. _dummy_header: -* Replace ``trainer.strategy=null`` with ``trainer.strategy=auto`` as `lightning 2.0 doesn't have None strategy `_. +* Replace ``trainer.strategy=null`` with ``trainer.strategy=auto`` as `lightning 2.0 doesn't have None strategy `__. -* Remove ``resume_from_checkpoint`` if being used as a trainer flag and pass the path to `Trainer.fit(ckpt_path="...") method `_. +* Remove ``resume_from_checkpoint`` if being used as a trainer flag and pass the path to `Trainer.fit(ckpt_path="...") method `__. * Set ``trainer.strategy = "ddp_find_unused_parameters_true"`` if there are unused parameters in your model as lightning 2.0 has find_unused_parameters as False by default. - Reference: `NeMo PR 6433 `_. More details about this change: `lightning PR 16611 `_. + Reference: `NeMo PR 6433 `__. More details about this change: `lightning PR 16611 `__. -* If used Trainer's flag ``replace_sampler_ddp`` replace it with `use_distributed_sampler `_. +* If used Trainer's flag ``replace_sampler_ddp`` replace it with `use_distributed_sampler `__. -* If using ``CheckpointConnector`` replace it with `_CheckpointConnector `_. +* If using ``CheckpointConnector`` replace it with `_CheckpointConnector `__. * To set or get ``ckpt_path`` use ``trainer.ckpt_path`` directly instead of calling protected API via ``trainer._checkpoint_connector._ckpt_path`` or using ``trainer._checkpoint_connector.resume_from_checkpoint_fit_path``. * Change ``import load`` from pytorch_lightning.utilities.cloud_io to ``import _load``. -* If used ``from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin`` from replace it with `from pytorch_lightning.plugins.precision import MixedPrecisionPlugin `_. +* If used ``from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin`` from replace it with `from pytorch_lightning.plugins.precision import MixedPrecisionPlugin `__. * Lightning 2.0 adds ``'16-mixed'``, ``'bf16-mixed'`` as the preicison values for fp16 mixed precision and bf16 mixed precision respectively. - For backward compatbility ``16`` or ``'16'`` and ``'bf16'`` also perform mixed precision and is equivalent to ``'16-mixed'`` and ``'bf16-mixed'`` respectively. However, lightning recommends to use ``'16-mixed'`` and ``'bf16-mixed'`` to make it less ambiguous. Due to this, ``MegatronHalfPrecisionPlugin's`` parent class from lightning ``MixedPrecisionPlugin`` class, expects the precision arg to be ``'16-mixed'`` and ``'bf16-mixed'``. As a result it's required to pass ``'16-mixed'`` or ``'bf16-mixed'`` to ``MixedPrecisionPLugin`` whenever the precision passed is any of ``[16, '16', '16-mixed']`` or ``['bf16', 'bf16-mixed']``. This can be taken care as shown here: `NeMo upgrade to lightning 2.0 PR `_ and here: `MixedPrecisionPlugin `_. Also, ``'32-true'`` is added as a precsion value for pure fp32 along with ``32``, ``'32'`` that existed. This can be taken into account as shown here in the `NeMo upgrade to lightning 2.0 PR `_. + For backward compatbility ``16`` or ``'16'`` and ``'bf16'`` also perform mixed precision and is equivalent to ``'16-mixed'`` and ``'bf16-mixed'`` respectively. However, lightning recommends to use ``'16-mixed'`` and ``'bf16-mixed'`` to make it less ambiguous. Due to this, ``MegatronHalfPrecisionPlugin's`` parent class from lightning ``MixedPrecisionPlugin`` class, expects the precision arg to be ``'16-mixed'`` and ``'bf16-mixed'``. As a result it's required to pass ``'16-mixed'`` or ``'bf16-mixed'`` to ``MixedPrecisionPLugin`` whenever the precision passed is any of ``[16, '16', '16-mixed']`` or ``['bf16', 'bf16-mixed']``. This can be taken care as shown here: `NeMo upgrade to lightning 2.0 PR `__ and here: `MixedPrecisionPlugin `__. Also, ``'32-true'`` is added as a precsion value for pure fp32 along with ``32``, ``'32'`` that existed. This can be taken into account as shown here in the `NeMo upgrade to lightning 2.0 PR `__. -* Lightning 2.0 renames epoch end hooks from ``training_epoch_end``, ``validation_epoch_end``, ``test_epoch_end`` to ``on_train_epoch_end``, ``on_validation_epoch_end``, ``on_test_epoch_end``. The renamed hooks do not accept the outputs arg but instead outputs needs to be defined as an instance variable of the model class to which the outputs of the step needs to be manually appended. More detailed examples implementing this can be found under migration guide of `lightning's PR 16520 `_. Example from NeMo can be found `here `_. +* Lightning 2.0 renames epoch end hooks from ``training_epoch_end``, ``validation_epoch_end``, ``test_epoch_end`` to ``on_train_epoch_end``, ``on_validation_epoch_end``, ``on_test_epoch_end``. The renamed hooks do not accept the outputs arg but instead outputs needs to be defined as an instance variable of the model class to which the outputs of the step needs to be manually appended. More detailed examples implementing this can be found under migration guide of `lightning's PR 16520 `__. Example from NeMo can be found `here `__. * Lightning 2.0 is not currently supporting multiple dataloders for validation and testing in case of ``dataloader_iter``. The support for this will be added back soon in an upcoming release. If ``dataloader_iter`` is being used and your config passes multiple files to ``validation_ds.file_names`` or ``test_ds.file_names``, please use just one file until this issue is fixed with pytorch lightning. * With lightning 2.0 it's required to set ``limit_val_batches`` and ``num_sanity_val_steps`` to be a multiple of number of microbatches while using ``dataloader_iter`` (applies only to Megatron files that use dataloader_iter) for all pretraining files (not downstream tasks like finetuning). This is being taken care internally in NeMo and does not require anything to be done by the user. However, if you are a developer of NeMo and are building a new model for pretraining that uses ``dataloader_iter`` instead of batch in ``validation_step`` methods please make sure to call ``self._reconfigure_val_batches()`` in ``build_train_valid_test_datasets method`` of your model. * If model is being wrapped with ``LightningDistributedModule`` in ``configure_ddp`` method please replace it with ``_LightningModuleWrapperBase`` - as being done here: `NeMo upgrade to lightning 2.0 PR `_. + as being done here: `NeMo upgrade to lightning 2.0 PR `__. -* If using ``pre_configure_ddp()`` in your DDP, remove it as it's not required anymore. `NeMo upgrade to lightning 2.0 PR `_. +* If using ``pre_configure_ddp()`` in your DDP, remove it as it's not required anymore. `NeMo upgrade to lightning 2.0 PR `__. * If any of the tests use CPU as the device, ensure to explicitly pass it in the trainer as ``trainer = pl.Trainer(max_epochs=1, accelerator='cpu')`` since deafult val in PTL >= 2.0 is auto and it picks cuda. diff --git a/docs/source/tools/nemo_forced_aligner.rst b/docs/source/tools/nemo_forced_aligner.rst index aa8d2139653f..df872e7d2195 100644 --- a/docs/source/tools/nemo_forced_aligner.rst +++ b/docs/source/tools/nemo_forced_aligner.rst @@ -12,14 +12,14 @@ NFA can be used on long audio files of 1+ hours duration (subject to your hardwa Demos & Tutorials ----------------- -* HuggingFace Space `demo `_ to quickly try out NFA in various languages. -* NFA "how-to" notebook `tutorial `_. -* "How forced alignment works" NeMo blog `tutorial `_. +* HuggingFace Space `demo `__ to quickly try out NFA in various languages. +* NFA "how-to" notebook `tutorial `__. +* "How forced alignment works" NeMo blog `tutorial `__. Quickstart ---------- -1. Install `NeMo `_. +1. Install `NeMo `__. 2. Prepare a NeMo-style manifest containing the paths of audio files you would like to proces, and (optionally) their text. 3. Run NFA's ``align.py`` script with the desired config, e.g.: diff --git a/docs/source/vision/checkpoint.rst b/docs/source/vision/checkpoint.rst index 7e3e197a1169..49848b90d51a 100644 --- a/docs/source/vision/checkpoint.rst +++ b/docs/source/vision/checkpoint.rst @@ -63,7 +63,7 @@ ViT Checkpoints To adjust model parallelism from original model parallelism size to a new model parallelism size (Note: NeMo ViT currently only supports `pipeline_model_parallel_size=1`): -.. code-block:: python +.. code-block:: bash python examples/nlp/language_modeling/megatron_change_num_partitions.py \ --model_file=/path/to/source.nemo \ diff --git a/docs/source/vision/vit.rst b/docs/source/vision/vit.rst index 679313bcbd66..a7b4e2546f22 100644 --- a/docs/source/vision/vit.rst +++ b/docs/source/vision/vit.rst @@ -4,7 +4,7 @@ ViT Model Introduction ------------------- -The Vision Transformer, commonly referred to as ViT :cite:`vision-models-vit`, serves as a foundational model +The Vision Transformer, commonly referred to as ViT :cite:`vision-models-vit-vit`, serves as a foundational model for image classification tasks in NeMo. Unlike conventional convolutional neural networks, ViT adopts a transformer-like architecture to process image data. In this approach, an image is divided into fixed-size patches, typically 14x14 or 16x16. These patches are linearly embedded and augmented with position embeddings. The resulting @@ -136,5 +136,5 @@ Reference .. bibliography:: ./vision_all.bib :style: plain :filter: docname in docnames - :labelprefix: VISION-MODELS - :keyprefix: vision-models- + :labelprefix: VISION-MODELS-VIT + :keyprefix: vision-models-vit- diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index 4420318dd416..e14424cec5c1 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -203,9 +203,9 @@ def forward_for_export( """ This forward is used when we need to export the model to ONNX format. Inputs cache_last_channel and cache_last_time are needed to be passed for exporting streaming models. + Args: - input: Tensor that represents a batch of raw audio signals, - of shape [B, T]. T here represents timesteps. + input: Tensor that represents a batch of raw audio signals of shape [B, T]. T here represents timesteps. length: Vector of length B, that contains the individual lengths of the audio sequences. cache_last_channel: Tensor of shape [N, B, T, H] which contains the cache for last channel layers cache_last_time: Tensor of shape [N, B, H, T] which contains the cache for last time layers diff --git a/nemo/collections/asr/models/msdd_models.py b/nemo/collections/asr/models/msdd_models.py index d96bafd5af9b..01926eb4ae79 100644 --- a/nemo/collections/asr/models/msdd_models.py +++ b/nemo/collections/asr/models/msdd_models.py @@ -400,10 +400,15 @@ def get_cluster_avg_embs_model( multi-scale input tensors during forward propagating. Example: `batch_size=3, scale_n=6, emb_dim=192` - ms_seg_counts = - [[8, 9, 12, 16, 25, 51], - [11, 13, 14, 17, 25, 51], - [ 9, 9, 11, 16, 23, 50]] + .. code:: python + + ms_seg_counts = + [ + [ 8, 9, 12, 16, 25, 51], + [11, 13, 14, 17, 25, 51], + [ 9, 9, 11, 16, 23, 50] + ] + Counts of merged segments: (121, 131, 118) embs has shape of (370, 192) clus_label_index has shape of (3, 131) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 5a7457f6379d..055066c00660 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -1559,13 +1559,13 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tens NOTE: The implementation of this model is slightly modified from the original paper. The original paper proposes the following steps : - (enc, dec) -> Expand + Concat + Sum [B, T, U, H1+H2] -> Forward through joint hidden [B, T, U, H] -- *1 - *1 -> Forward through joint final [B, T, U, V + 1]. + (enc, dec) -> Expand + Concat + Sum [B, T, U, H1+H2] -> Forward through joint hidden [B, T, U, H] -- \*1 + \*1 -> Forward through joint final [B, T, U, V + 1]. We instead split the joint hidden into joint_hidden_enc and joint_hidden_dec and act as follows: - enc -> Forward through joint_hidden_enc -> Expand [B, T, 1, H] -- *1 - dec -> Forward through joint_hidden_dec -> Expand [B, 1, U, H] -- *2 - (*1, *2) -> Sum [B, T, U, H] -> Forward through joint final [B, T, U, V + 1]. + enc -> Forward through joint_hidden_enc -> Expand [B, T, 1, H] -- \*1 + dec -> Forward through joint_hidden_dec -> Expand [B, 1, U, H] -- \*2 + (\*1, \*2) -> Sum [B, T, U, H] -> Forward through joint final [B, T, U, V + 1]. Args: f: Output of the Encoder model. A torch.Tensor of shape [B, T, H1] @@ -2050,8 +2050,7 @@ def sampled_joint( """ Compute the sampled joint step of the network. - # Reference - - [Memory-Efficient Training of RNN-Transducer with Sampled Softmax](https://arxiv.org/abs/2203.16868) + Reference: `Memory-Efficient Training of RNN-Transducer with Sampled Softmax `__. Here, B = Batch size @@ -2065,13 +2064,13 @@ def sampled_joint( NOTE: The implementation of this joint model is slightly modified from the original paper. The original paper proposes the following steps : - (enc, dec) -> Expand + Concat + Sum [B, T, U, H1+H2] -> Forward through joint hidden [B, T, U, H] -- *1 - *1 -> Forward through joint final [B, T, U, V + 1]. + (enc, dec) -> Expand + Concat + Sum [B, T, U, H1+H2] -> Forward through joint hidden [B, T, U, H] -- \*1 + \*1 -> Forward through joint final [B, T, U, V + 1]. We instead split the joint hidden into joint_hidden_enc and joint_hidden_dec and act as follows: - enc -> Forward through joint_hidden_enc -> Expand [B, T, 1, H] -- *1 - dec -> Forward through joint_hidden_dec -> Expand [B, 1, U, H] -- *2 - (*1, *2) -> Sum [B, T, U, H] -> Sample Vocab V_Pos (for target tokens) and V_Neg -> + enc -> Forward through joint_hidden_enc -> Expand [B, T, 1, H] -- \*1 + dec -> Forward through joint_hidden_dec -> Expand [B, 1, U, H] -- \*2 + (\*1, \*2) -> Sum [B, T, U, H] -> Sample Vocab V_Pos (for target tokens) and V_Neg -> (V_Neg is sampled not uniformly by as a rand permutation of all vocab tokens, then eliminate all Intersection(V_Pos, V_Neg) common tokens to avoid duplication of loss) -> Concat new Vocab V_Sampled = Union(V_Pos, V_Neg) diff --git a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py index b264890ce48d..dc0cef692ee2 100644 --- a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py +++ b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py @@ -26,9 +26,10 @@ class AutoTokenizer(TokenizerSpec): - ''' + """ Wrapper of HuggingFace AutoTokenizer https://huggingface.co/transformers/model_doc/auto.html#autotokenizer. - ''' + + """ def __init__( self, @@ -52,7 +53,7 @@ def __init__( For more details please refer to https://huggingface.co/transformers/_modules/transformers/tokenization_auto.html#AutoTokenizer.from_pretrained. The list of all supported models can be found here: ALL_PRETRAINED_CONFIG_ARCHIVE_MAP vocab_file: path to file with vocabulary which consists - of characters separated by '\n'. + of characters separated by newlines. mask_token: mask token bos_token: the beginning of sequence token eos_token: the end of sequence token. Usually equal to sep_token @@ -167,11 +168,13 @@ def add_special_tokens(self, special_tokens_dict: dict) -> int: """ Adds a dictionary of special tokens (eos, pad, cls...). If special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the current vocabulary). + Args: special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``]. - Tokens are only added if they are not already in the vocabulary. + Tokens are only added if they are not already in the vocabulary. + Returns: Number of tokens added to the vocabulary. """ diff --git a/nemo/collections/nlp/data/language_modeling/megatron/t5_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/t5_dataset.py index 72f4fd0e12a1..f0efaf5cd1aa 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/t5_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/t5_dataset.py @@ -252,7 +252,8 @@ def build_training_sample( skip_masking_id=None, ): """Build training sample. - Arguments: + + Args: sample: A list of sentences in which each sentence is a list token ids. target_seq_length: Desired sequence length. max_seq_length: Maximum length of the sequence. All values are padded to diff --git a/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py index 5ed0da009cf2..fb8ec9554a95 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py @@ -72,10 +72,10 @@ def load_data(self, dataset): """ Loads a dataset by filling in the task templates specified in the config file with the information from each training/inference example. Converts all input - text into token ids. Also replaces the <|VIRTUAL_PROMPT_#|> placeholders in + text into token ids. Also replaces the ``<|VIRTUAL_PROMPT_#|>`` placeholders in the task templates with the actual virtual prompt token ids. - params: + Args: dataset: A list of json objects or a dictionary objects each containing the information needed for a training example """ diff --git a/nemo/collections/nlp/data/language_modeling/megatron/ul2_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/ul2_dataset.py index c2d19305cf03..485388d84343 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/ul2_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/ul2_dataset.py @@ -25,6 +25,7 @@ class UL2Dataset(T5Dataset): """ UL2 Dataset from https://arxiv.org/abs/2205.05131. Consists of three different objectives: + 1. Short span masking with small probabilities (ex: T5). Typically max ngram size of 5 with 0.15 mask prob. 2. Extreme span masking with either large probabilities or large ngram sizes or both. 3. Prefx-LM as in the T5 or LM-adapted T5 (prompt-tuning paper). @@ -312,7 +313,8 @@ def build_extreme_masking_training_sample( skip_masking_id=None, ): """Build training sample. - Arguments: + + Args: sample: A list of sentences in which each sentence is a list token ids. target_seq_length: Desired sequence length. max_seq_length: Maximum length of the sequence. All values are padded to diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py index d974c8182234..102ab5ec0f84 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py @@ -182,9 +182,11 @@ def build_train_valid_test_datasets(self): return self._train_ds, self._validation_ds, self._test_ds def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. - We setup datasets here as megatron datasets require DDP to instantiate. - See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + """ + PTL hook that is executed after DDP spawns. + We setup datasets here as megatron datasets require DDP to instantiate. + See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + Args: stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. """ diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py index dc6d81649122..0f1fa76f9b01 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py @@ -760,9 +760,11 @@ def _append_sequence_parallel_module_grads(self, module, grads): grads.append(grad.data) def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. - We setup datasets here as megatron datasets require DDP to instantiate. - See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + """ + PTL hook that is executed after DDP spawns. + We setup datasets here as megatron datasets require DDP to instantiate. + See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + Args: stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. """ diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 7a2f3459470c..d7f489abf158 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1475,9 +1475,11 @@ def build_pretraining_data_loader( ) def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. - We setup datasets here as megatron datasets require DDP to instantiate. - See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + """ + PTL hook that is executed after DDP spawns. + We setup datasets here as megatron datasets require DDP to instantiate. + See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + Args: stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. """ diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py index 459bf5b71c7e..4c39bd877b4a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py @@ -911,9 +911,11 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, num_workers): ) def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. - We setup datasets here as megatron datasets require DDP to instantiate. - See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + """ + PTL hook that is executed after DDP spawns. + We setup datasets here as megatron datasets require DDP to instantiate. + See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + Args: stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. """ @@ -1413,11 +1415,13 @@ def dummy(): def complete(self, request: Dict): """ - Autoregressively invokes language model in the inference mode + Autoregressively invokes language model in the inference mode + Args: request: Dictionary with the following fields * prompt: a string which text the model should complete. * tokens_to_generate: how many tokens to generate while doing prompt completion. + Returns: response: A python dictionary with the following fields * prompt: original text of the prompt diff --git a/nemo/collections/nlp/modules/common/transformer/text_generation.py b/nemo/collections/nlp/modules/common/transformer/text_generation.py index a4e37935adc9..5f0275ff4553 100644 --- a/nemo/collections/nlp/modules/common/transformer/text_generation.py +++ b/nemo/collections/nlp/modules/common/transformer/text_generation.py @@ -67,47 +67,48 @@ def generate( inputs (Union[List[str], Tensor, List[dict]]): Can be one of the 3 types: - 1. List of strings. Each element of the list provides input prompt. The model will apply tokenizer on it. - E.g [‘sentence’, ‘sentence2’ … ] + 1. List of strings. Each element of the list provides input prompt. The model will apply tokenizer on it. + E.g [‘sentence’, ‘sentence2’ … ] - 2. Tuple of Pytorch Tensors (context_tokens, context_lengths). The `context_tokens` has shape (batch_size, seq_length), it's the batched sequences of tokens used as a prompst for the generation or as model inputs to the encoder. - The generative model will skip the tokenization and padding step. The `context_lengths` has shape (batch_size,), it indicates the length of the context tokens for each of the input sequences. - E.g. ( torch.tensor([[23,5234,23,35,…], [223,323,23,23232,232,...] …]), torch.tensor([20, 30, …])) + 2. Tuple of Pytorch Tensors (context_tokens, context_lengths). The `context_tokens` has shape (batch_size, seq_length), it's the batched sequences of tokens used as a prompst for the generation or as model inputs to the encoder. + The generative model will skip the tokenization and padding step. The `context_lengths` has shape (batch_size,), it indicates the length of the context tokens for each of the input sequences. + E.g. ( torch.tensor([[23,5234,23,35,…], [223,323,23,23232,232,...] …]), torch.tensor([20, 30, …])) - 3. List of python dict objects. Used for prompt/p-tuning inputs where a set of key-value pairs are converted into input token embeddings for the model. - E.g. [{"prompt-tag": "sentiment", "sentence": "this is a good movie"}, - {"prompt-tag": "qa", "context": "some context text", "question": "a simple question"} ... ] - where 'prompt-tag' is used to identify the type of NLP task to solve. + 3. List of python dict objects. Used for prompt/p-tuning inputs where a set of key-value pairs are converted into input token embeddings for the model. + E.g. [{"prompt-tag": "sentiment", "sentence": "this is a good movie"}, + {"prompt-tag": "qa", "context": "some context text", "question": "a simple question"} ... ] + where 'prompt-tag' is used to identify the type of NLP task to solve. length_params (LengthParam): a dictionary type which controls the sampling length. - max_length: int, The maximum length of the sequence to be generated. - - min_length: int, The minimum length of the sequence to be generated. + * max_length: int, The maximum length of the sequence to be generated. + * min_length: int, The minimum length of the sequence to be generated. If None, max_length is set to 30, and min_length is set to None + sampling_params (SamplingParam): a dictionary type which contains the parameters for text sampling. It has the following keys - use_greedy: bool, Whether or not to use sampling ; use greedy decoding otherwise - top_k: int, The number of highest probability vocabulary tokens to keep for top-k-filtering. - top_p: float, If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. - repetition_penalty: float, The parameter for repetition penalty. 1.0 means no penalty. - add_BOS: bool, Whether add the bos token at the begining of the prompt - all_probs: bool # whether return the log prob for all the tokens in vocab - compute_logprob: bool # a flag used to compute logprob of all the input text, a very special case of running inference, default False - end_strings: List[str] # generation will stop when one of these tokens is generated + * use_greedy: bool, Whether or not to use sampling ; use greedy decoding otherwise + * top_k: int, The number of highest probability vocabulary tokens to keep for top-k-filtering. + * top_p: float, If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + * repetition_penalty: float, The parameter for repetition penalty. 1.0 means no penalty. + * add_BOS: bool, Whether add the bos token at the begining of the prompt + * all_probs: bool # whether return the log prob for all the tokens in vocab + * compute_logprob: bool # a flag used to compute logprob of all the input text, a very special case of running inference, default False + * end_strings: List[str] # generation will stop when one of these tokens is generated + Default None, If it is None, use_greedy will be "True". Returns: - OutputType: It generates the output in a dictionary type. It has the following keys: - - sentences: List[str], output sentences - tokens: List[List[str]], output sentences borken into tokens - logprob: List[List[float]], log prob of generated tokens - full_logprob: List[List[float]], log prob of all the tokens in the vocab - token_ids: List[List[int]], output sentence token ids - offsets: List[List[int]] # list of tokens start positions in text + It generates the output in a dictionary type. It has the following keys, + + * sentences: List[str], output sentences + * tokens: List[List[str]], output sentences borken into tokens + * logprob: List[List[float]], log prob of generated tokens + * full_logprob: List[List[float]], log prob of all the tokens in the vocab + * token_ids: List[List[int]], output sentence token ids + * offsets: List[List[int]] # list of tokens start positions in text """ raise NotImplementedError("please implement this method") diff --git a/nemo/collections/vision/models/megatron_vit_classification_models.py b/nemo/collections/vision/models/megatron_vit_classification_models.py index c27c37c2b917..ea6d3578c540 100644 --- a/nemo/collections/vision/models/megatron_vit_classification_models.py +++ b/nemo/collections/vision/models/megatron_vit_classification_models.py @@ -621,9 +621,11 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, drop_last=Tru ) def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. - We setup datasets here as megatron datasets require DDP to instantiate. - See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + """ + PTL hook that is executed after DDP spawns. + We setup datasets here as megatron datasets require DDP to instantiate. + See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + Args: stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. """ diff --git a/nemo/core/classes/dataset.py b/nemo/core/classes/dataset.py index 738ae22f5416..789fc0b863d7 100644 --- a/nemo/core/classes/dataset.py +++ b/nemo/core/classes/dataset.py @@ -42,12 +42,15 @@ def collate_fn(self, batch): Please note, subclasses of Dataset should not implement `input_types`. - # Usage: - dataloader = torch.utils.data.DataLoader( - ...., - collate_fn=dataset.collate_fn, - .... - ) + Usage: + + .. code-block:: python + + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... + ) Returns: Collated batch, with or without types. diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index be9a6e8cfbb3..5c7cac5a9a55 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -304,9 +304,9 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo recent checkpoint under ``*last.ckpt``, and the final checkpoint after training completes under ``*end.ckpt``. Defaults to True. - create_early_stopping_callback (bool): Flag to decide if early stopping should be used to stop training. Default is False. - See EarlyStoppingParams dataclass above. + See EarlyStoppingParams dataclass above. - create_preemption_callback (bool): Flag to decide whether to enable preemption callback to save checkpoints and exit training - immediately upon preemption. Default is True. + immediately upon preemption. Default is True. - files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None which copies no files. - log_local_rank_0_only (bool): Whether to only create log files for local rank 0. Defaults to False. From 9e2325d18b4a0e6576ffabe8003c3cad26eb3954 Mon Sep 17 00:00:00 2001 From: Valerie Sarge Date: Wed, 1 May 2024 16:34:21 -0700 Subject: [PATCH 06/73] Handle case where num_query_groups is set to null for LoRA config setup (#9075) Signed-off-by: Valerie Sarge --- nemo/collections/nlp/parts/peft_config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index 47d5167d630e..820e2ad63f24 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -123,6 +123,9 @@ def __init__(self, cfg): kv_channels = self._calculate_kv_channels(cfg) projection_size = kv_channels * cfg.num_attention_heads num_query_groups = cfg.get("num_query_groups", cfg.num_attention_heads) + if num_query_groups is None: + # Cover the case where num_query_groups is explicitly set to null + num_query_groups = cfg.num_attention_heads qkv_projection_size = projection_size + (2 * kv_channels * num_query_groups) From d66ca999b80bb9da0af05da13b6b3b51142535dc Mon Sep 17 00:00:00 2001 From: Alexey Panteleev Date: Wed, 1 May 2024 17:33:32 -0700 Subject: [PATCH 07/73] TRT-LLM export P-tuning related fixes (#8863) * Fixed the uses of pathlib.Path. Signed-off-by: Alexey Panteleev * Add the bos token to LLAMA based models. Signed-off-by: Alexey Panteleev * P-tuning related fixes: - Remember the vtoken counts for each p-tuning table when the tables are added; - Prepend the right number of vtokens to each query based on its task_id; - Preserve the dtype of the p-tuning table when it is padded; - Validate that all p-tuning tables fit into max_prompt_embedding_table_size limit. Signed-off-by: Alexey Panteleev * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alexey Panteleev Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Pablo Garay --- nemo/export/tensorrt_llm.py | 24 ++++++++-- nemo/export/trt_llm/tensorrt_llm_model.py | 4 +- nemo/export/trt_llm/tensorrt_llm_run.py | 55 ++++++++++++++++++----- 3 files changed, 68 insertions(+), 15 deletions(-) diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 40fb93816a33..033044b3b328 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -97,6 +97,7 @@ def __init__(self, model_dir: str, lora_ckpt_list: List[str] = None, load_model: self.ptuning_tables = [] self.p_table = None self.task_vocab_size = 0 + self.task_vtoken_counts = [] self.task_ids = {} if load_model: @@ -358,12 +359,15 @@ def forward( prompt_embeddings_table, prompt_embeddings_checkpoint_path ) tv_size = prompt_table.size(dim=0) + task_vtoken_counts = [tv_size] elif len(self.ptuning_tables) > 0: prompt_table = self.p_table tv_size = self.task_vocab_size + task_vtoken_counts = self.task_vtoken_counts else: prompt_table = None tv_size = None + task_vtoken_counts = None if task_ids is None: assert prompt_table is None, "There is a prompt embedding table and task_ids cannot be None" @@ -404,6 +408,7 @@ def forward( temperature=temperature, prompt_table=prompt_table, task_vocab_size=tv_size, + task_vtoken_counts=task_vtoken_counts, task_ids=input_task_ids, lora_uids=lora_uids, stop_words_list=stop_words_list, @@ -423,6 +428,7 @@ def forward( temperature=temperature, prompt_table=prompt_table, task_vocab_size=tv_size, + task_vtoken_counts=task_vtoken_counts, task_ids=input_task_ids, lora_uids=lora_uids, stop_words_list=stop_words_list, @@ -578,19 +584,31 @@ def _prep_ptuning_table(self): if self.task_vocab_size < pt["table"].size(dim=0): self.task_vocab_size = pt["table"].size(dim=0) - # pad tasks to longest task embedding table + # pad tasks to longest task embedding table, remember the original task vtoken counts vtokens_embeddings = [] + self.task_vtoken_counts = [] self.task_ids = {} tid = 0 for i, ptuning_table in enumerate(self.ptuning_tables): - padded_table = torch.zeros((self.task_vocab_size, self.get_hidden_size)) - padded_table[: ptuning_table["table"].size(dim=0), :] = ptuning_table["table"] + original_table = ptuning_table["table"] + vtoken_count = original_table.size(dim=0) + padded_table = torch.zeros((self.task_vocab_size, self.get_hidden_size), dtype=original_table.dtype) + padded_table[:vtoken_count, :] = original_table vtokens_embeddings.append(padded_table) self.task_ids[ptuning_table["task_name"]] = tid + self.task_vtoken_counts.append(vtoken_count) tid = tid + 1 if len(vtokens_embeddings) > 0: self.p_table = torch.stack(vtokens_embeddings, dim=0).view(-1, self.get_hidden_size) + + max_prompt_embedding_table_size = self.config['builder_config']['max_prompt_embedding_table_size'] + actual_prompt_table_size = self.p_table.shape[0] + + if actual_prompt_table_size > max_prompt_embedding_table_size: + raise Exception( + f"The size of the combined prompt embedding table ({actual_prompt_table_size}) is greater than max_prompt_embedding_table_size ({max_prompt_embedding_table_size})." + ) else: self.p_table = None diff --git a/nemo/export/trt_llm/tensorrt_llm_model.py b/nemo/export/trt_llm/tensorrt_llm_model.py index 52e9c4960fc9..736d6180807e 100644 --- a/nemo/export/trt_llm/tensorrt_llm_model.py +++ b/nemo/export/trt_llm/tensorrt_llm_model.py @@ -26,7 +26,7 @@ from tensorrt_llm.module import Module, ModuleList from nemo.export.trt_llm.decoder import build_decoder_layer -from nemo.export.trt_llm.model_config import DECODER_GEMMA, ModelConfig +from nemo.export.trt_llm.model_config import DECODER_GEMMA, DECODER_LLAMA, ModelConfig from nemo.export.trt_llm.quantization_utils import quantize_linear from nemo.export.trt_llm.tensorrt_llm_build import build from nemo.export.trt_llm.tensorrt_llm_utils import ( @@ -65,7 +65,7 @@ def __init__(self, model_config: ModelConfig): else model_config.head_size ) self._use_prompt_tuning = model_config.use_prompt_tuning - self._add_bos = model_config.layers[0].decoder_type == DECODER_GEMMA + self._add_bos = model_config.layers[0].decoder_type in (DECODER_GEMMA, DECODER_LLAMA) self._mapping = model_config.mapping self.rank = model_config.mapping.rank self.max_lora_rank = model_config.max_lora_rank diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index d7e3e40c87a2..c490f37e1fc4 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -491,6 +491,47 @@ def forward( raise RuntimeError("Internal error") +def prepare_input_tensors( + input_texts: List[str], + host_context: TensorrtLLMHostContext, + prompt_table=None, + task_vtoken_counts: List[int] = None, + task_ids: List[int] = None, +): + tokenizer = host_context.tokenizer + + if host_context.add_bos: + bos_tokens = [tokenizer.bos_token_id] + else: + bos_tokens = [] + + input_tokens = [bos_tokens + tokenizer.encode(t) for t in input_texts] + + # If p-tuning is used, we need to prepend vtokens to each input. + if prompt_table is not None: + + # Go over the tokenized prompts and prepend vtokens. + # The number of vtokens could be different for each task. + for prompt_index in range(len(input_texts)): + # Find out the number of vtokens to generate + task_id = task_ids[prompt_index] + num_vtokens = task_vtoken_counts[task_id] + + # Create a tensor with vtokens, e.g. 32000, 32001, 32002... when vocab_size=32000 + # TRT-LLM will convert each vtoken into its corresponding embedding row from the prompt table. + vocab_size = tokenizer.vocab_size + vtokens = list(range(vocab_size, vocab_size + num_vtokens)) + + # Concatenate the vtokens with the real tokens + real_tokens = input_tokens[prompt_index] + input_tokens[prompt_index] = vtokens + real_tokens + + # Convert input token lists to tensors + input_tensors = [torch.IntTensor(token_list) for token_list in input_tokens] + + return input_tensors + + def generate( input_texts: List[str], max_output_len: int, @@ -500,6 +541,7 @@ def generate( temperature: float = 1.0, prompt_table=None, task_vocab_size=None, + task_vtoken_counts: List[int] = None, task_ids: List[int] = None, lora_uids: List[str] = None, stop_words_list=None, @@ -515,11 +557,7 @@ def generate( Returns a 2D string list with shape [batch_size, num_beams]. """ tokenizer = host_context.tokenizer - - if host_context.add_bos: - input_tensors = [torch.IntTensor([tokenizer.bos_token_id] + tokenizer.encode(t)) for t in input_texts] - else: - input_tensors = [torch.IntTensor(tokenizer.encode(t)) for t in input_texts] + input_tensors = prepare_input_tensors(input_texts, host_context, prompt_table, task_vtoken_counts, task_ids) stop_words_list_tensors = None if stop_words_list is not None: @@ -582,6 +620,7 @@ def generate_streaming( temperature: float = 1.0, prompt_table=None, task_vocab_size=None, + task_vtoken_counts: List[int] = None, task_ids: List[int] = None, lora_uids: List[str] = None, stop_words_list=None, @@ -594,11 +633,7 @@ def generate_streaming( Returns a 2D string list with shape [batch_size, num_beams]. """ tokenizer = host_context.tokenizer - - if host_context.add_bos: - input_tensors = [torch.IntTensor([tokenizer.bos_token_id] + tokenizer.encode(t)) for t in input_texts] - else: - input_tensors = [torch.IntTensor(tokenizer.encode(t)) for t in input_texts] + input_tensors = prepare_input_tensors(input_texts, host_context, prompt_table, task_vtoken_counts, task_ids) batch_size = len(input_texts) From 0643511a29101801afad070c80d26040d48eaa3a Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Thu, 2 May 2024 05:47:11 +0200 Subject: [PATCH 08/73] [NeMo-UX] Add mixed-precision plugin (#9065) * Adding MegatronParallel * Move over _strategy_liMegatronCheckpointIO * Adding GPTModel & MockDataModule * Adding mixed-precision to NeMo * Fix import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert unintended changes Signed-off-by: Chen Cui * clean up code and reinstate mix precision tests Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean up Signed-off-by: Chen Cui * use cpu for unit test Signed-off-by: Chen Cui --------- Signed-off-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Chen Cui --- nemo/lightning/__init__.py | 11 +- nemo/lightning/pytorch/plugins/__init__.py | 6 +- .../pytorch/plugins/mixed_precision.py | 166 ++++++++++++++++++ tests/lightning/test_megatron_parallel.py | 106 ++++++----- 4 files changed, 232 insertions(+), 57 deletions(-) create mode 100644 nemo/lightning/pytorch/plugins/mixed_precision.py diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py index f900345f96eb..afbdb39f42d4 100644 --- a/nemo/lightning/__init__.py +++ b/nemo/lightning/__init__.py @@ -4,7 +4,7 @@ from pytorch_lightning import plugins as _pl_plugins from nemo.lightning.base import get_vocab_size, teardown -from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler from nemo.lightning.pytorch.strategies import MegatronStrategy from nemo.lightning.pytorch.trainer import Trainer @@ -22,4 +22,11 @@ def _is_slurm_interactive_mode(): _pl_plugins._PLUGIN_INPUT = Union[_pl_plugins._PLUGIN_INPUT, _data_sampler.DataSampler] # noqa: SLF001 -__all__ = ["MegatronStrategy", "MegatronDataSampler", "Trainer", "get_vocab_size", "teardown"] +__all__ = [ + "MegatronStrategy", + "MegatronDataSampler", + "MegatronMixedPrecision", + "Trainer", + "get_vocab_size", + "teardown", +] diff --git a/nemo/lightning/pytorch/plugins/__init__.py b/nemo/lightning/pytorch/plugins/__init__.py index 45f88a383681..d99e1a3ca7b9 100644 --- a/nemo/lightning/pytorch/plugins/__init__.py +++ b/nemo/lightning/pytorch/plugins/__init__.py @@ -1,3 +1,7 @@ from nemo.lightning.pytorch.plugins.data_sampler import MegatronDataSampler +from nemo.lightning.pytorch.plugins.mixed_precision import MegatronMixedPrecision -__all__ = ["MegatronDataSampler"] +__all__ = [ + "MegatronDataSampler", + "MegatronMixedPrecision", +] diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py new file mode 100644 index 000000000000..af7054526957 --- /dev/null +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager +from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union + +import pytorch_lightning as pl +import torch +from pytorch_lightning.plugins.precision import MixedPrecision +from torch.nn import Module +from torch.optim import Optimizer + +from nemo.lightning._strategy_lib import GradScaler + +AnyT = TypeVar("AnyT") + + +class MegatronMixedPrecision(MixedPrecision): + def __init__(self, precision: Literal["16-mixed", "bf16-mixed"], amp_O2: bool = True, device="cuda",) -> None: + if precision == "bf16-mixed": + scaler = None + else: + scaler = GradScaler(init_scale=2 ** 32, growth_interval=1000, hysteresis=2) + + super().__init__(precision, device, scaler) + + # MixedPrecisionPlugin class in PTL >= 2.0 takes only "16-mixed" or "bf16-mixed" for precision arg + if precision == "16-mixed": + dtype = torch.float16 + + def float16_convertor(val): + return val.half() + + elif precision == "bf16-mixed": + dtype = torch.bfloat16 + + def float16_convertor(val): + return val.bfloat16() + + else: + raise ValueError("precision must be '16-mixed' or 'bf16-mixed'") + + self.dtype = dtype + torch.set_autocast_gpu_dtype(dtype) + self.float16_convertor = float16_convertor + self.amp_O2 = amp_O2 + + def connect( + self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any] + ) -> Tuple[Module, List[Optimizer], List[Any]]: + """Connects this plugin to the accelerator and the training process.""" + from nemo.core.optim import MainParamsOptimizerWrapper + + if not optimizers or not self.amp_O2 or isinstance(optimizers[0], MainParamsOptimizerWrapper): + return model, optimizers, lr_schedulers + + _optimizers = [*optimizers] + _optimizers[0] = self.convert_optimizer(_optimizers[0]) + + return model, _optimizers, lr_schedulers + + def convert_module(self, module: Module) -> Module: + """Convert the module parameters to the precision type this plugin handles. + + This is optional and depends on the precision limitations during optimization. + + """ + if self.precision == "bf16-mixed": + return module.bfloat16() + if self.precision == "16-mixed": + return module.half() + + return module + + def convert_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Convert the optimizer parameters to the precision type this plugin handles. + + This is optional and depends on the precision limitations during optimization. + + """ + from nemo.core.optim import MainParamsOptimizerWrapper + + if isinstance(optimizer, MainParamsOptimizerWrapper) or not self.amp_O2: + return optimizer + + return MainParamsOptimizerWrapper(optimizer, fp32_grad_accum=True, contiguous_grad_bucket=True,) + + def convert_input(self, data: AnyT) -> AnyT: + """Convert model inputs (forward) to the floating point precision type of this plugin. + + Note: MegatronStrategy will take care of only doing this when: + parallel_state.is_pipeline_first_stage() + + """ + from megatron.core.transformer.module import fp32_to_float16 + + return fp32_to_float16(data, self.float16_convertor) + + def convert_output(self, data: AnyT) -> AnyT: + """Convert outputs to the floating point precision type expected after model's forward. + + Note: MegatronStrategy will take care of only doing this when: + parallel_state.is_pipeline_last_stage() + + """ + from megatron.core.transformer.module import float16_to_fp32 + + return float16_to_fp32(data) + + def optimizer_step( + self, + optimizer: torch.optim.Optimizer, + model: Union["pl.LightningModule", torch.nn.Module], + closure: Callable[[], Any], + **kwargs: Any, + ) -> None: + from nemo.core.optim import MainParamsOptimizerWrapper + + if not self.amp_O2 and not isinstance(optimizer, MainParamsOptimizerWrapper): + return super().optimizer_step(optimizer, model, closure, **kwargs) + + if self.scaler is None: + assert optimizer.fp32_grad_accumulation, "BF16 uses FP32 grad accumulation" + _ = closure() + self._after_closure(model, optimizer) + return optimizer.step(**kwargs) + + assert not optimizer.fp32_grad_accumulation, "FP16 uses FP16 grad accumulation" + closure_result = closure() + + # TODO: Add an option for merged all-reduce + + # cast fp16 grads to fp32 and copy to main grads, which are used for unscale and param update + optimizer.copy_model_grads_to_main_grads() + # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook. + # unscale main (fp32) gradients + self.scaler.unscale_(optimizer) + self._after_closure(model, optimizer) + skipped_backward = closure_result is None + # in manual optimization, the closure does not return a value + if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward: + # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found + self.scaler.step(optimizer, **kwargs) + self.scaler.update() + + @contextmanager + def forward_context(self) -> Generator[None, None, None]: + """No explicit precision casting. Inputs are supposed to be manually casted.""" + try: + yield + finally: + pass + + +__all__ = ["MegatronMixedPrecision"] diff --git a/tests/lightning/test_megatron_parallel.py b/tests/lightning/test_megatron_parallel.py index 06e614d48251..877e6a39a976 100644 --- a/tests/lightning/test_megatron_parallel.py +++ b/tests/lightning/test_megatron_parallel.py @@ -1,6 +1,7 @@ from collections import defaultdict import pytest +from megatron.core import parallel_state from torch import nn from nemo import lightning as nl @@ -24,11 +25,10 @@ def forward(self, x): return DummyModule() - # TODO (chcui): Uncomment this test when we merge mixed-precision - # @pytest.fixture - # def mock_precision_plugin(self, mocker): - # """Fixture to create a mock precision plugin.""" - # return nl.MegatronMixedPrecision(precision="bf16-mixed") + @pytest.fixture + def mock_precision_plugin(self, mocker): + """Fixture to create a mock precision plugin.""" + return nl.MegatronMixedPrecision(precision="bf16-mixed") @pytest.fixture def mock_callbacks(self, mocker): @@ -64,55 +64,53 @@ def test_init_with_defaults(self, mocker, mock_pipeline): assert megatron_parallel.forward_step == mp.default_forward_step assert megatron_parallel.loss_reduction is None - # TODO (chcui): Uncomment this test when we merge mixed-precision - # def test_init_with_custom_parameters( - # self, - # mocker, - # mock_pipeline, - # mock_precision_plugin, - # mock_callbacks, - # mock_data_step, - # mock_forward_step, - # mock_loss_reduction - # ): - # """Test __init__ with custom parameters.""" - # mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1) - # mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False) - # - # megatron_parallel = mp.MegatronParallel( - # pipeline=mock_pipeline, - # precision_plugin=mock_precision_plugin, - # callbacks=mock_callbacks, - # data_step=mock_data_step, - # forward_step=mock_forward_step, - # loss_reduction=mock_loss_reduction - # ) - # - # assert megatron_parallel.pipeline == mock_pipeline - # assert megatron_parallel.precision_plugin == mock_precision_plugin - # assert megatron_parallel.callbacks == mock_callbacks - # assert megatron_parallel.data_step == mock_data_step - # assert megatron_parallel.forward_step == mock_forward_step - # assert megatron_parallel.loss_reduction == mock_loss_reduction - - # TODO: Comment-out this test when we merge nemo.io - # def test_init_with_virtual_pipeline(self, mocker, mock_pipeline): - # """Test __init__ with virtual pipeline model parallel world size.""" - # mocker.patch('torch.distributed.get_rank', return_value=1) - # mocker.patch('megatron.core.parallel_state.get_tensor_model_parallel_group', return_value=1) - # mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_group', return_value=1) - # mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=2) - # mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=True) - # mocker.patch('megatron.core.parallel_state.set_virtual_pipeline_model_parallel_world_size') - # mocker.patch('megatron.core.parallel_state.set_virtual_pipeline_model_parallel_rank') - # mocker.patch('nemo_ext.lightning._strategy_lib.init_lightning_module', return_value=mock_pipeline) - - # megatron_parallel = mp.MegatronParallel(mock_pipeline, vp_size=2) - - # assert len(megatron_parallel.pipeline) == 2 - # assert all(isinstance(mod, nn.Module) for mod in megatron_parallel.pipeline) - # megatron.core.parallel_state.set_virtual_pipeline_model_parallel_world_size.assert_called_once_with(2) - # assert megatron.core.parallel_state.set_virtual_pipeline_model_parallel_rank.call_count == 1 + def test_init_with_custom_parameters( + self, + mocker, + mock_pipeline, + mock_precision_plugin, + mock_callbacks, + mock_data_step, + mock_forward_step, + mock_loss_reduction, + ): + """Test __init__ with custom parameters.""" + mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1) + mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False) + + megatron_parallel = mp.MegatronParallel( + pipeline=mock_pipeline, + precision_plugin=mock_precision_plugin, + callbacks=mock_callbacks, + data_step=mock_data_step, + forward_step=mock_forward_step, + loss_reduction=mock_loss_reduction, + ) + + assert megatron_parallel.pipeline == mock_pipeline + assert megatron_parallel.precision_plugin == mock_precision_plugin + assert megatron_parallel.callbacks == mock_callbacks + assert megatron_parallel.data_step == mock_data_step + assert megatron_parallel.forward_step == mock_forward_step + assert megatron_parallel.loss_reduction == mock_loss_reduction + + def test_init_with_virtual_pipeline(self, mocker, mock_pipeline): + """Test __init__ with virtual pipeline model parallel world size.""" + mocker.patch('torch.distributed.get_rank', return_value=1) + mocker.patch('megatron.core.parallel_state.get_tensor_model_parallel_group', return_value=1) + mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_group', return_value=1) + mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=2) + mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=True) + mocker.patch('megatron.core.parallel_state.set_virtual_pipeline_model_parallel_world_size') + mocker.patch('megatron.core.parallel_state.set_virtual_pipeline_model_parallel_rank') + mocker.patch('nemo.io.reinit', return_value=mock_pipeline) + + megatron_parallel = mp.MegatronParallel(mock_pipeline, vp_size=2, cpu=True) + + assert len(megatron_parallel.pipeline) == 2 + assert all(isinstance(mod, nn.Module) for mod in megatron_parallel.pipeline) + parallel_state.set_virtual_pipeline_model_parallel_world_size.assert_called_once_with(2) + assert parallel_state.set_virtual_pipeline_model_parallel_rank.call_count == 1 class TestCallbackConnector: From a8e0ca1b6206b4158c96781176f5b0d80b49f9cc Mon Sep 17 00:00:00 2001 From: Eric Harper Date: Thu, 2 May 2024 00:06:20 -0600 Subject: [PATCH 09/73] Comment baichuan test and update pr template (#9085) * comment test Signed-off-by: eharper * comment test Signed-off-by: eharper --------- Signed-off-by: eharper --- .github/PULL_REQUEST_TEMPLATE.md | 8 ++--- .github/workflows/cicd-main.yml | 51 +++++++++++++++++--------------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 2c4946bbbde1..ae22ede4807b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -14,13 +14,13 @@ Add a one line overview of what this PR aims to accomplish. # Add a code snippet demonstrating how to use this ``` -# Jenkins CI +# GitHub Actions CI The Jenkins CI system has been replaced by GitHub Actions self-hosted runners. -There's no need to comment `jenkins` on the PR to trigger Jenkins CI. -The GitHub Actions CI will run automatically when the PR is opened. -To run CI on an untrusted fork, a NeMo user with write access must click "Approve and run". +The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR. +To re-run CI remove and add the label again. +To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run". # Before your PR is "Ready for review" **Pre checks**: diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 6f090bd34213..df631443e7f7 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -319,29 +319,32 @@ jobs: - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" - L2_Community_LLM_Checkpoints_tests_Baichuan2: - needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - python scripts/checkpoint_converters/convert_baichuan2_hf_to_nemo.py \ - --input_name_or_path=/home/TestData/nlp/megatron_gpt/Baichuan2-7B-Base \ - --output_path=/home/TestData/nlp/megatron_gpt/Baichuan2-7B-Base/ci.nemo - rm -f /home/TestData/nlp/megatron_gpt/Baichuan2-7B-Base/ci.nemo - - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" + # this test is using a 7B model which is too large for GitHub CI + # replace the model in this test with a toy model or move the test + # to the nightly CI + # L2_Community_LLM_Checkpoints_tests_Baichuan2: + # needs: [cicd-test-container-setup] + # runs-on: self-hosted-azure + # container: + # image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + # options: + # # --user 0:128 + # --device=/dev/nvidia0 + # --gpus all + # --shm-size=8g + # --env TRANSFORMERS_OFFLINE=0 + # --env HYDRA_FULL_ERROR=1 + # --volume /mnt/datadrive/TestData:/home/TestData + # steps: + # - name: Checkout repository + # uses: actions/checkout@v4 + # - run: | + # python scripts/checkpoint_converters/convert_baichuan2_hf_to_nemo.py \ + # --input_name_or_path=/home/TestData/nlp/megatron_gpt/Baichuan2-7B-Base \ + # --output_path=/home/TestData/nlp/megatron_gpt/Baichuan2-7B-Base/ci.nemo + # rm -f /home/TestData/nlp/megatron_gpt/Baichuan2-7B-Base/ci.nemo + # - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + # if: "failure()" L2_PTQ_Llama2_Export_Only: needs: [cicd-test-container-setup] @@ -6370,7 +6373,7 @@ jobs: - L2_Community_LLM_Checkpoints_tests_Llama - L2_Community_LLM_Checkpoints_tests_StarCoder - L2_Community_LLM_Checkpoints_tests_Falcon - - L2_Community_LLM_Checkpoints_tests_Baichuan2 + #- L2_Community_LLM_Checkpoints_tests_Baichuan2 - ASR_dev_run_Speech_to_Text - ASR_dev_run_Speech_to_Text_WPE_-_CitriNet - ASR_dev_run_Speech_Pre-training_-_CitriNet From f15e8975fd12e23ca1fd887222e7a636c52a8167 Mon Sep 17 00:00:00 2001 From: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Date: Thu, 2 May 2024 09:48:10 -0700 Subject: [PATCH 10/73] Add safe extraction of nemo tar files (#8976) * Add safe extraction of nemo tar files Signed-off-by: Abhishree * Fix bugs Signed-off-by: Abhishree * Replace print with logging Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> --------- Signed-off-by: Abhishree Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: Eric Harper Co-authored-by: Pablo Garay --- .../core/connectors/save_restore_connector.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index 2d01e9d5bad8..70d91066b7f0 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ b/nemo/core/connectors/save_restore_connector.py @@ -553,6 +553,29 @@ def _make_nemo_file_from_folder(filename, source_dir): with tarfile.open(filename, "w:") as tar: tar.add(source_dir, arcname=".") + @staticmethod + def _is_safe_path(member, extract_to): + # Check for path traversal characters or absolute paths + member_path = os.path.normpath(member.name) + # Ensure the path does not start with a slash or contain ".." after normalization + if os.path.isabs(member_path) or ".." in member_path.split(os.sep): + return False + # Construct the full path where the member would be extracted + full_path = os.path.join(extract_to, member_path) + # Ensure the member would be extracted within the intended directory + return os.path.commonprefix([full_path, extract_to]) == extract_to + + @staticmethod + def _safe_extract(tar, out_folder: str, members=None): + extract_to = os.path.realpath(out_folder) + if members is None: + members = tar.getmembers() + for member in members: + if SaveRestoreConnector._is_safe_path(member, extract_to): + tar.extract(member, extract_to) + else: + logging.warning(f"Skipping potentially unsafe member: {member.name}") + @staticmethod def _unpack_nemo_file(path2file: str, out_folder: str, extract_config_only: bool = False) -> str: if not os.path.exists(path2file): @@ -569,10 +592,10 @@ def _unpack_nemo_file(path2file: str, out_folder: str, extract_config_only: bool tar_header = "r:gz" tar = tarfile.open(path2file, tar_header) if not extract_config_only: - tar.extractall(path=out_folder) + SaveRestoreConnector._safe_extract(tar, out_folder) else: members = [x for x in tar.getmembers() if ".yaml" in x.name] - tar.extractall(path=out_folder, members=members) + SaveRestoreConnector._safe_extract(tar, out_folder, members) tar.close() return out_folder From 9100cfd6462e1dbd5119b5affa845b1c061f265b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 2 May 2024 14:13:46 -0400 Subject: [PATCH 11/73] PyTorch CUDA allocator optimization for dynamic batch shape dataloading in ASR (#9061) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Option to auto-set expandable_segments in PyTorch CUDA allocator Signed-off-by: Piotr Żelasko * warning Signed-off-by: Piotr Żelasko * set opts after parsing config Signed-off-by: Piotr Żelasko --------- Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/dataloader.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index eabc3da5d11b..191ac54589e5 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os import warnings from dataclasses import dataclass from functools import partial @@ -74,6 +74,7 @@ class LhotseDataLoadingConfig: drop_last: bool = False shard_seed: int | str = "trng" max_open_streams: int | None = None + cuda_expandable_segments: bool = True # 2.1 Multimodal sampling override options use_multimodal_sampling: bool = False @@ -150,6 +151,8 @@ def get_lhotse_dataloader_from_config( config = make_structured_with_schema_warnings(config) + maybe_set_cuda_expandable_segments(enabled=config.cuda_expandable_segments) + # First, resolve the random seed in case a string value was provided. seed = resolve_seed(config.seed) fix_random_seed(seed) @@ -451,6 +454,28 @@ def _flatten_alt_text(cut) -> list: return ans +def maybe_set_cuda_expandable_segments(enabled: bool): + """ + Configures PyTorch memory allocator to expand existing allocated segments + instead of re-allocating them when tensor shape grows. + This can help speed up the training when sequence length and/or batch size change often, + and makes GPU more robust towards OOM. + + See here for more details: + https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf + """ + if enabled and torch.cuda.is_available(): + if ( + (value := os.environ.get("PYTORCH_CUDA_ALLOC_CONF")) is not None + and len(value) > 0 + and "expandable_segments:True" not in value + ): + warnings.warn( + "You have set PYTORCH_CUDA_ALLOC_CONF without expandable_segments:True option. We're setting that option anyway. To disable it, set cuda_expandable_segments=False in NeMo dataloader configuration." + ) + torch.cuda.memory._set_allocator_settings("expandable_segments:True") + + def _select_channel(cut, channel_selector: int | str) -> list: if isinstance(channel_selector, int): channel_idx = channel_selector From f769ad504890f4798fe2a679d337d0ddf2c05fe3 Mon Sep 17 00:00:00 2001 From: Ryan Langman Date: Thu, 2 May 2024 12:04:32 -0700 Subject: [PATCH 12/73] [TTS] Add tutorial for training audio codecs (#8723) * [TTS] Add tutorial for training audio codecs Signed-off-by: Ryan * [TTS] Update tutorial Signed-off-by: Ryan * [TTS] Add diagrams Signed-off-by: Ryan * [TTS] Add introduction and references Signed-off-by: Ryan * [TTS] Replace diagram with github release link Signed-off-by: Ryan --------- Signed-off-by: Ryan --- examples/tts/audio_codec.py | 1 + .../conf/audio_codec/audio_codec_16000.yaml | 8 +- .../conf/audio_codec/audio_codec_24000.yaml | 12 +- .../tts/conf/audio_codec/encodec_24000.yaml | 8 +- .../tts/conf/audio_codec/mel_codec_22050.yaml | 194 +++++ .../tts/conf/audio_codec/mel_codec_44100.yaml | 10 +- tutorials/tts/Audio_Codec_Training.ipynb | 800 ++++++++++++++++++ 7 files changed, 1011 insertions(+), 22 deletions(-) create mode 100644 examples/tts/conf/audio_codec/mel_codec_22050.yaml create mode 100644 tutorials/tts/Audio_Codec_Training.ipynb diff --git a/examples/tts/audio_codec.py b/examples/tts/audio_codec.py index 800edfb7fb0f..5fc4b6fd0afd 100644 --- a/examples/tts/audio_codec.py +++ b/examples/tts/audio_codec.py @@ -27,6 +27,7 @@ def main(cfg): trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) model = AudioCodecModel(cfg=cfg.model, trainer=trainer) + model.maybe_init_from_pretrained_checkpoint(cfg=cfg) trainer.fit(model) diff --git a/examples/tts/conf/audio_codec/audio_codec_16000.yaml b/examples/tts/conf/audio_codec/audio_codec_16000.yaml index 7182414a31db..93b44b579655 100644 --- a/examples/tts/conf/audio_codec/audio_codec_16000.yaml +++ b/examples/tts/conf/audio_codec/audio_codec_16000.yaml @@ -92,13 +92,13 @@ model: log_epochs: [1, 2, 3, 4, 5, 6] epoch_frequency: 1 log_tensorboard: false - log_wandb: true + log_wandb: false generators: - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator log_audio: true - log_encoding: true - log_dequantized: true + log_encoding: false + log_dequantized: false dataset: _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset @@ -129,8 +129,6 @@ model: _target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]] - # The original EnCodec uses hinged loss, but squared-GAN loss is more stable - # and reduces the need to tune the loss weights or use a gradient balancer. generator_loss: _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss diff --git a/examples/tts/conf/audio_codec/audio_codec_24000.yaml b/examples/tts/conf/audio_codec/audio_codec_24000.yaml index e5e386722fb1..cf48db807d25 100644 --- a/examples/tts/conf/audio_codec/audio_codec_24000.yaml +++ b/examples/tts/conf/audio_codec/audio_codec_24000.yaml @@ -2,7 +2,7 @@ # If you want to train model on other dataset, you can change config values according to your dataset. # Most dataset-specific arguments are in the head of the config file, see below. -name: EnCodec +name: AudioCodec max_epochs: ??? # Adjust batch size based on GPU memory @@ -90,13 +90,13 @@ model: log_epochs: [10, 50, 100, 150, 200] epoch_frequency: 100 log_tensorboard: false - log_wandb: true + log_wandb: false generators: - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator log_audio: true - log_encoding: true - log_dequantized: true + log_encoding: false + log_dequantized: false dataset: _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset @@ -127,8 +127,6 @@ model: _target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]] - # The original EnCodec uses hinged loss, but squared-GAN loss is more stable - # and reduces the need to tune the loss weights or use a gradient balancer. generator_loss: _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss @@ -162,7 +160,7 @@ exp_manager: exp_dir: null name: ${name} create_tensorboard_logger: false - create_wandb_logger: true + create_wandb_logger: false wandb_logger_kwargs: name: null project: null diff --git a/examples/tts/conf/audio_codec/encodec_24000.yaml b/examples/tts/conf/audio_codec/encodec_24000.yaml index 4898d449d520..be66fd4b4979 100644 --- a/examples/tts/conf/audio_codec/encodec_24000.yaml +++ b/examples/tts/conf/audio_codec/encodec_24000.yaml @@ -90,13 +90,13 @@ model: log_epochs: [10, 50, 100, 150, 200] epoch_frequency: 100 log_tensorboard: false - log_wandb: true + log_wandb: false generators: - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator log_audio: true - log_encoding: true - log_dequantized: true + log_encoding: false + log_dequantized: false dataset: _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset @@ -162,7 +162,7 @@ exp_manager: exp_dir: null name: ${name} create_tensorboard_logger: false - create_wandb_logger: true + create_wandb_logger: false wandb_logger_kwargs: name: null project: null diff --git a/examples/tts/conf/audio_codec/mel_codec_22050.yaml b/examples/tts/conf/audio_codec/mel_codec_22050.yaml new file mode 100644 index 000000000000..df77e7747a51 --- /dev/null +++ b/examples/tts/conf/audio_codec/mel_codec_22050.yaml @@ -0,0 +1,194 @@ +# This config contains the default values for training 22.05kHz audio codec model which encodes mel spectrogram +# instead of raw audio. +# If you want to train model on other dataset, you can change config values according to your dataset. +# Most dataset-specific arguments are in the head of the config file, see below. + +name: MelCodec + +max_epochs: ??? +# Adjust batch size based on GPU memory +batch_size: 16 +# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. +# If null, then weighted sampling is disabled. +weighted_sampling_steps_per_epoch: null + +# Dataset metadata for each manifest +# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 +train_ds_meta: ??? +val_ds_meta: ??? + +log_ds_meta: ??? +log_dir: ??? + +# Modify these values based on your sample rate +sample_rate: 22050 +win_length: 1024 +hop_length: 256 +train_n_samples: 8192 # ~0.37 seconds +# The product of the up_sample_rates should match the hop_length. +# For example 8 * 8 * 2 * 2 = 256. +up_sample_rates: [8, 8, 2, 2] + + +model: + + max_epochs: ${max_epochs} + steps_per_epoch: ${weighted_sampling_steps_per_epoch} + + sample_rate: ${sample_rate} + samples_per_frame: ${hop_length} + + mel_loss_l1_scale: 1.0 + mel_loss_l2_scale: 0.0 + stft_loss_scale: 20.0 + time_domain_loss_scale: 0.0 + commit_loss_scale: 0.0 + + # Probability of updating the discriminator during each training step + # For example, update the discriminator 1/2 times (1 update for every 2 batches) + disc_updates_per_period: 1 + disc_update_period: 2 + + # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length] + loss_resolutions: [ + [32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048] + ] + mel_loss_dims: [5, 10, 20, 40, 80, 160, 320] + mel_loss_log_guard: 1.0 + stft_loss_log_guard: 1.0 + feature_loss_type: absolute + + train_ds: + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + dataset_meta: ${train_ds_meta} + weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} + sample_rate: ${sample_rate} + n_samples: ${train_n_samples} + min_duration: 0.4 + max_duration: null + + dataloader_params: + batch_size: ${batch_size} + drop_last: true + num_workers: 4 + + validation_ds: + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + sample_rate: ${sample_rate} + n_samples: null + min_duration: null + max_duration: null + trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss + dataset_meta: ${val_ds_meta} + + dataloader_params: + batch_size: 4 + num_workers: 2 + + # Configures how audio samples are generated and saved during training. + # Remove this section to disable logging. + log_config: + log_dir: ${log_dir} + log_epochs: [10, 50, 100, 150, 200] + epoch_frequency: 100 + log_tensorboard: false + log_wandb: false + + generators: + - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator + log_audio: true + log_encoding: false + log_dequantized: false + + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + sample_rate: ${sample_rate} + n_samples: null + min_duration: null + max_duration: null + trunc_duration: 10.0 # Only log the first 10 seconds of generated audio. + dataset_meta: ${log_ds_meta} + + dataloader_params: + batch_size: 4 + num_workers: 2 + + audio_encoder: + _target_: nemo.collections.tts.modules.audio_codec_modules.MultiBandMelEncoder + mel_bands: [[0, 10], [10, 20], [20, 30], [30, 40], [40, 50], [50, 60], [60, 70], [70, 80]] + out_channels: 4 # The dimension of each codebook + hidden_channels: 128 + filters: 256 + mel_processor: + _target_: nemo.collections.tts.modules.audio_codec_modules.MelSpectrogramProcessor + mel_dim: 80 + sample_rate: ${sample_rate} + win_length: ${win_length} + hop_length: ${hop_length} + + audio_decoder: + _target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANDecoder + up_sample_rates: ${up_sample_rates} + input_dim: 32 # Should be equal to len(audio_encoder.mel_bands) * audio_encoder.out_channels + base_channels: 1024 # This is double the base channels of HiFi-GAN V1, making it approximately 4x larger. + + vector_quantizer: + _target_: nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer + num_groups: 8 # Should equal len(audio_encoder.mel_bands) + num_levels_per_group: [8, 5, 5, 5] # 8 * 5 * 5 * 5 = 1000 entries per codebook + + discriminator: + _target_: nemo.collections.tts.modules.audio_codec_modules.Discriminator + discriminators: + - _target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT + resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]] + - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator + + generator_loss: + _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss + + discriminator_loss: + _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss + + optim: + _target_: torch.optim.Adam + lr: 2e-4 + betas: [0.8, 0.99] + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: 1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: 16 + max_epochs: ${max_epochs} + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 5 + benchmark: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: false + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + resume_if_exists: false + resume_ignore_no_checkpoint: false diff --git a/examples/tts/conf/audio_codec/mel_codec_44100.yaml b/examples/tts/conf/audio_codec/mel_codec_44100.yaml index 15d12f009ae0..3ae528df6a64 100644 --- a/examples/tts/conf/audio_codec/mel_codec_44100.yaml +++ b/examples/tts/conf/audio_codec/mel_codec_44100.yaml @@ -94,13 +94,13 @@ model: log_epochs: [10, 50, 100, 150, 200] epoch_frequency: 100 log_tensorboard: false - log_wandb: true + log_wandb: false generators: - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator log_audio: true - log_encoding: true - log_dequantized: true + log_encoding: false + log_dequantized: false dataset: _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset @@ -146,8 +146,6 @@ model: resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]] - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator - # The original EnCodec uses hinged loss, but squared-GAN loss is more stable - # and reduces the need to tune the loss weights or use a gradient balancer. generator_loss: _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss @@ -181,7 +179,7 @@ exp_manager: exp_dir: null name: ${name} create_tensorboard_logger: false - create_wandb_logger: true + create_wandb_logger: false wandb_logger_kwargs: name: null project: null diff --git a/tutorials/tts/Audio_Codec_Training.ipynb b/tutorials/tts/Audio_Codec_Training.ipynb new file mode 100644 index 000000000000..5f42fd73aa2c --- /dev/null +++ b/tutorials/tts/Audio_Codec_Training.ipynb @@ -0,0 +1,800 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "7X-TwhdTGmlc" + }, + "source": [ + "# License" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fCQUeZRPGnoe" + }, + "source": [ + "> Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n", + ">\n", + "> Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n", + ">\n", + "> http://www.apache.org/licenses/LICENSE-2.0\n", + ">\n", + "> Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rtBDkKqVGZJ8" + }, + "source": [ + "# Introduction" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pZ2QSsXuGbMe" + }, + "source": [ + "In this tutorial we show how to use NeMo to train and fine-tune **neural audio codecs**.\n", + "\n", + "Neural audio codecs are deep learning models that compress audio into a low bitrate representation. The compact embedding space created by these models can be useful for various speech tasks, such as TTS and ASR.\n", + "\n", + "
\n", + "\n", + "
\n", + "\n", + "Audio codec models typically have an *encoder-quantizer-decoder* structure. The **encoder** takes an input audio signal and encodes it into a sequence of embeddings. The **quantizer** discretizes the embeddings to create a lookup table known as a **codebook**. The embeddings saved in the codebook are referred to as **audio codes**. The **decoder** takes the audio codes as input and attempts to reconstruct the original audio signal.\n", + "\n", + "To store compressed audio we only need to save the codebook index for each embedding in an audio sequence. This is how audio codec models achieve low bitrates. The codebook indices for an audio are referred to **audio tokens**. It is becoming common for speech generation models to synthesize speech by predicting audio tokens.\n", + "\n", + "In NeMo we have implementations of the [SEANet encoder and decoder](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/modules/encodec_modules.py#L146) used by [EnCodec](https://github.com/facebookresearch/encodec). As well as a [ResNet encoder](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/modules/audio_codec_modules.py#L1035) and [HiFi-GAN decoder](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/modules/audio_codec_modules.py#L875). For quantizers we support [Residual Vector Quantizer](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/modules/encodec_modules.py#L694) (**RVQ**) and [Finite Scalar Quantizer](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/modules/audio_codec_modules.py#L409) (**FSQ**).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3OZassNG5xff" + }, + "source": [ + "# Install" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WZvQvPkIhRi3" + }, + "outputs": [], + "source": [ + "BRANCH = 'main'\n", + "# Install NeMo library. If you are running locally (rather than on Google Colab), comment out the below line\n", + "# and instead follow the instructions at https://github.com/NVIDIA/NeMo#Installation\n", + "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]" + ] + }, + { + "cell_type": "code", + "source": [ + "from pathlib import Path" + ], + "metadata": { + "id": "v8NGOM0EzK8W" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tvsgWO_WhV3M" + }, + "outputs": [], + "source": [ + "# Directory where tutorialscripts will run and outputs will be saved.\n", + "ROOT_DIR = Path().absolute() / \"codec_tutorial\"\n", + "\n", + "# Nemo code paths\n", + "NEMO_DIR = ROOT_DIR / \"nemo\"\n", + "NEMO_SCRIPT_DIR = NEMO_DIR / \"scripts\" / \"dataset_processing\" / \"tts\"\n", + "NEMO_EXAMPLES_DIR = NEMO_DIR / \"examples\" / \"tts\"\n", + "NEMO_CONFIG_DIR = NEMO_EXAMPLES_DIR / \"conf\"\n", + "\n", + "nemo_download_dir = str(NEMO_DIR)\n", + "# Download local version of NeMo scripts. If you are running locally and want to use your own local NeMo code,\n", + "# comment out the below line and set NEMO_ROOT_DIR to your local path.\n", + "!git clone -b $BRANCH https://github.com/NVIDIA/NeMo.git $nemo_download_dir" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KAbH7N427FdT" + }, + "source": [ + "# Configuration" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Predefined model configurations are available in https://github.com/NVIDIA/NeMo/tree/main/examples/tts/conf/audio_codec.\n", + "\n", + "Configurations available include:\n", + "\n", + "* **audio_codec_*.yaml**: Audio codec configurations optimized for various sampling rates.\n", + "* **mel_codec_*.yaml**: A mel-spectrogram based codec designed to maximize the performance of speech synthesis models.\n", + "* **encodec_*.yaml**: A reproduction of the original [EnCodec](https://arxiv.org/abs/2210.13438) model setup.\n", + "\n", + "This tutorial can be run with any of our predefined configs. As a default we have selected `audio_codec_16000.yaml`, which works for 16kHz audio." + ], + "metadata": { + "id": "ODgdGgsAAUku" + } + }, + { + "cell_type": "code", + "source": [ + "from omegaconf import OmegaConf" + ], + "metadata": { + "id": "SPtjS2LkzE9Q" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "CONFIG_FILENAME = \"audio_codec_16000.yaml\"\n", + "CONFIG_DIR = NEMO_CONFIG_DIR / \"audio_codec\"\n", + "\n", + "config_filepath = CONFIG_DIR / CONFIG_FILENAME\n", + "\n", + "if not config_filepath.exists():\n", + " raise ValueError(f\"Config file does not exist {config_filepath}\")" + ], + "metadata": { + "id": "iCPJFKg63Dsv" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Read model name and sample rate from model configuration\n", + "omega_conf = OmegaConf.load(config_filepath)\n", + "MODEL_NAME = omega_conf.name\n", + "SAMPLE_RATE = omega_conf.sample_rate\n", + "print(f\"Training {MODEL_NAME} with sample rate {SAMPLE_RATE}\")" + ], + "metadata": { + "id": "QE0HYh7FjAR3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We provide pretrained model checkpoints for fine-tuning. The list of available models can be found [here](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/models/audio_codec.py#L645)." + ], + "metadata": { + "id": "W7F--_0maLh5" + } + }, + { + "cell_type": "code", + "source": [ + "import wget\n", + "from nemo.collections.tts.models.audio_codec import AudioCodecModel\n", + "\n", + "# Optionally specify a pretrained model to fine-tune from. To train from scratch, set this to 'None'.\n", + "pretrained_model_name = \"audio_codec_16khz_small\"\n", + "\n", + "if pretrained_model_name is None:\n", + " MODEL_CHECKPOINT_PATH = None\n", + "else:\n", + " model_list = AudioCodecModel.list_available_models()\n", + "\n", + " pretrained_model_url = None\n", + " for model in model_list:\n", + " if model.pretrained_model_name == pretrained_model_name:\n", + " pretrained_model_url = model.location\n", + " break\n", + "\n", + " if pretrained_model_url is None:\n", + " raise ValueError(f\"Could not find pretrained model {pretrained_model_name}. Models available {model_list}\")\n", + "\n", + " # Optionally load pretrained checkpoint\n", + " MODEL_CHECKPOINT_PATH = ROOT_DIR / \"models\" / f\"{pretrained_model_name}.nemo\"\n", + "\n", + " if not MODEL_CHECKPOINT_PATH.exists():\n", + " print(f\"Downloading {pretrained_model_url} to {MODEL_CHECKPOINT_PATH}\")\n", + " MODEL_CHECKPOINT_PATH.parent.mkdir(exist_ok=True)\n", + " wget.download(pretrained_model_url, out=str(MODEL_CHECKPOINT_PATH))" + ], + "metadata": { + "id": "XqAYWR65aKTx" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fM4QPsLTnzK7" + }, + "source": [ + "# Dataset Preparation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tkZC6Dl7KRl6" + }, + "source": [ + "For our tutorial, we use a subset of [VCTK](https://datashare.ed.ac.uk/handle/10283/2950) dataset with 5 speakers (p225-p229)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sYzvAYr2vo1K" + }, + "outputs": [], + "source": [ + "import tarfile\n", + "\n", + "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aoxN1QsUzX-k" + }, + "outputs": [], + "source": [ + "# Create dataset directory\n", + "DATA_DIR = ROOT_DIR / \"data\"\n", + "\n", + "DATA_DIR.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mArlQd5Hk36b" + }, + "outputs": [], + "source": [ + "# Download the dataset\n", + "dataset_url = \"https://vctk-subset.s3.amazonaws.com/vctk_subset_multispeaker.tar.gz\"\n", + "dataset_tar_filepath = DATA_DIR / \"vctk.tar.gz\"\n", + "\n", + "if not dataset_tar_filepath.exists():\n", + " wget.download(dataset_url, out=str(dataset_tar_filepath))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "p987cjtOy9C7" + }, + "outputs": [], + "source": [ + "# Extract the dataset\n", + "with tarfile.open(dataset_tar_filepath) as tar_f:\n", + " tar_f.extractall(DATA_DIR)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ko6dxYJW0i3G" + }, + "outputs": [], + "source": [ + "DATASET_DIR = DATA_DIR / \"vctk_subset_multispeaker\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "We5FHYQt5BeO" + }, + "outputs": [], + "source": [ + "# Visualize the raw dataset\n", + "train_raw_filepath = DATASET_DIR / \"train.json\"\n", + "!head $train_raw_filepath" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i3jsk2HCMSU5" + }, + "source": [ + "## Manifest Processing" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N8WuAGJsMHRn" + }, + "source": [ + "The downloaded manifest is formatted for TTS training, which contains metadata such as text and speaker.\n", + "\n", + "For codec training we need `audio_filepath`. The `audio_filepath` field can either be an *absolute path*, or a *relative path* with the root directory provided as an argument to each script. Here we use relative paths.\n", + "\n", + "If you include `duration` the training script will automatically calculate the total size of each dataset used, and can be useful for filtering based on utterance length." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zoCRrKQ20VZP" + }, + "outputs": [], + "source": [ + "def update_manifest(data_type):\n", + " input_filepath = DATASET_DIR / f\"{data_type}.json\"\n", + " output_filepath = DATASET_DIR / f\"{data_type}_raw.json\"\n", + "\n", + " entries = read_manifest(input_filepath)\n", + " new_entries = []\n", + " for entry in entries:\n", + " # Provide relative path instead of absolute path\n", + " audio_filepath = entry[\"audio_filepath\"].replace(\"audio/\", \"\")\n", + " duration = round(entry[\"duration\"], 2)\n", + " new_entry = {\n", + " \"audio_filepath\": audio_filepath,\n", + " \"duration\": duration\n", + " }\n", + " new_entries.append(new_entry)\n", + "\n", + " write_manifest(output_path=output_filepath, target_manifest=new_entries, ensure_ascii=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PaCc3GCG1UbH" + }, + "outputs": [], + "source": [ + "update_manifest(\"dev\")\n", + "update_manifest(\"train\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bVLIB3Ip1Aqn" + }, + "outputs": [], + "source": [ + "# Visualize updated 'audio_filepath' field.\n", + "train_filepath = DATASET_DIR / \"train_raw.json\"\n", + "!head $train_filepath" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "alrRDWio41qi" + }, + "source": [ + "## Audio Preprocessing" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4WfEaMwpUsFt" + }, + "source": [ + "Next we process the audio data using [preprocess_audio.py](https://github.com/NVIDIA/NeMo/blob/main/scripts/dataset_processing/tts/preprocess_audio.py).\n", + "\n", + "During this step we can apply the following transformations:\n", + "\n", + "1. Resample the audio from 48khz to the target sample rate for codec training.\n", + "2. Remove long silence from the beginning and end of each audio file. This can be done using an *energy* based approach which will work on clean audio, or using *voice activity detection (VAD)* which is slower but also works on audio with background or static noise (eg. from a microphone). Here we suggest VAD because some audio in VCTK has background noise." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WEvIefjnd7AG" + }, + "outputs": [], + "source": [ + "import IPython.display as ipd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-qEuCH8S4vFP" + }, + "outputs": [], + "source": [ + "# Python wrapper to invoke the given bash script with the given input args\n", + "def run_script(script, args):\n", + " args = ' \\\\'.join(args)\n", + " cmd = f\"python {script} \\\\{args}\"\n", + "\n", + " print(cmd.replace(\" \\\\\", \"\\n\"))\n", + " print()\n", + " !$cmd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0kQ1UDnGfdX6" + }, + "outputs": [], + "source": [ + "audio_preprocessing_script = NEMO_SCRIPT_DIR / \"preprocess_audio.py\"\n", + "\n", + "# Directory with raw audio data\n", + "input_audio_dir = DATASET_DIR / \"audio\"\n", + "# Directory to write preprocessed audio to\n", + "output_audio_dir = DATASET_DIR / \"audio_preprocessed\"\n", + "# Whether to overwrite existing audio, if it exists in the output directory\n", + "overwrite_audio = True\n", + "# Whether to overwrite output manifest, if it exists\n", + "overwrite_manifest = True\n", + "# Number of threads to parallelize audio processing across\n", + "num_workers = 4\n", + "# Format of output audio files. Use \"flac\" to compress to a smaller file size.\n", + "output_format = \"flac\"\n", + "# Method for silence trimming. Can use \"energy.yaml\" or \"vad.yaml\".\n", + "trim_config_path = NEMO_CONFIG_DIR / \"trim\" / \"vad.yaml\"\n", + "\n", + "def preprocess_audio(data_type):\n", + " input_filepath = DATASET_DIR / f\"{data_type}_raw.json\"\n", + " output_filepath = DATASET_DIR / f\"{data_type}_manifest.json\"\n", + "\n", + " args = [\n", + " f\"--input_manifest={input_filepath}\",\n", + " f\"--output_manifest={output_filepath}\",\n", + " f\"--input_audio_dir={input_audio_dir}\",\n", + " f\"--output_audio_dir={output_audio_dir}\",\n", + " f\"--num_workers={num_workers}\",\n", + " f\"--output_sample_rate={SAMPLE_RATE}\",\n", + " f\"--output_format={output_format}\",\n", + " f\"--trim_config_path={trim_config_path}\"\n", + " ]\n", + " if overwrite_manifest:\n", + " args.append(\"--overwrite_manifest\")\n", + " if overwrite_audio:\n", + " args.append(\"--overwrite_audio\")\n", + "\n", + " run_script(audio_preprocessing_script, args)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ai0zbXSOriuY" + }, + "outputs": [], + "source": [ + "preprocess_audio(\"dev\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NUKnidQYfgDo" + }, + "outputs": [], + "source": [ + "preprocess_audio(\"train\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x2yhJtsj2lDR" + }, + "source": [ + "Before we proceed, it is important to verify that the audio processing works as expected. Let's listen to an audio file before and after processing.\n", + "\n", + "Note that the processed audio is shorter because we trimmed the leading and trailing silence." + ] + }, + { + "cell_type": "code", + "source": [ + "!ls $processed_audio_filepath" + ], + "metadata": { + "id": "AfdHUHAWuF-G" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_fM3GwJxkjOA" + }, + "outputs": [], + "source": [ + "audio_file = \"p228_009.wav\"\n", + "audio_filepath = input_audio_dir / audio_file\n", + "processed_audio_filepath = output_audio_dir / audio_file.replace(\".wav\", \".flac\")\n", + "\n", + "print(\"Original audio.\")\n", + "ipd.display(ipd.Audio(audio_filepath, rate=SAMPLE_RATE))\n", + "\n", + "print(\"Processed audio.\")\n", + "ipd.display(ipd.Audio(processed_audio_filepath, rate=SAMPLE_RATE))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oRO842MUyODC" + }, + "source": [ + "# Audio Codec Training" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "E4wUKYOfH8ax" + }, + "source": [ + "Here we show how to train an audio codec model from scratch. Instructions and checkpoints for fine-tuning will be provided later.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pqfl9jAYMJob" + }, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "from omegaconf import OmegaConf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jK2rr-Kr6Qg8" + }, + "outputs": [], + "source": [ + "dataset_name = \"vctk\"\n", + "audio_dir = DATASET_DIR / \"audio_preprocessed\"\n", + "train_manifest_filepath = DATASET_DIR / \"train_manifest.json\"\n", + "dev_manifest_filepath = DATASET_DIR / \"dev_manifest.json\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Vr4D-NB-yQx8" + }, + "outputs": [], + "source": [ + "audio_codec_training_script = NEMO_EXAMPLES_DIR / \"audio_codec.py\"\n", + "\n", + "# The total number of training steps will be (epochs * steps_per_epoch)\n", + "epochs = 10\n", + "steps_per_epoch = 10\n", + "\n", + "# Name of the experiment that will determine where it is saved locally and in TensorBoard and WandB\n", + "run_id = \"test_run\"\n", + "exp_dir = ROOT_DIR / \"exps\"\n", + "codec_exp_output_dir = exp_dir / MODEL_NAME / run_id\n", + "# Directory where predicted audio will be stored periodically throughout training\n", + "codec_log_dir = codec_exp_output_dir / \"logs\"\n", + "# Optionally log visualization of learned codes.\n", + "log_dequantized = True\n", + "# Optionally log predicted audio and other artifacts to WandB\n", + "log_to_wandb = False\n", + "# Optionally log predicted audio and other artifacts to Tensorboard\n", + "log_to_tensorboard = False\n", + "\n", + "if torch.cuda.is_available():\n", + " accelerator=\"gpu\"\n", + " batch_size = 4\n", + "else:\n", + " accelerator=\"cpu\"\n", + " batch_size = 2\n", + "\n", + "args = [\n", + " f\"--config-path={CONFIG_DIR}\",\n", + " f\"--config-name={CONFIG_FILENAME}\",\n", + " f\"max_epochs={epochs}\",\n", + " f\"weighted_sampling_steps_per_epoch={steps_per_epoch}\",\n", + " f\"batch_size={batch_size}\",\n", + " f\"log_dir={codec_log_dir}\",\n", + " f\"exp_manager.exp_dir={exp_dir}\",\n", + " f\"+exp_manager.version={run_id}\",\n", + " f\"model.log_config.log_wandb={log_to_wandb}\",\n", + " f\"model.log_config.log_tensorboard={log_to_tensorboard}\",\n", + " f\"model.log_config.generators.0.log_dequantized={log_dequantized}\",\n", + " f\"trainer.accelerator={accelerator}\",\n", + " f\"+train_ds_meta.{dataset_name}.manifest_path={train_manifest_filepath}\",\n", + " f\"+train_ds_meta.{dataset_name}.audio_dir={audio_dir}\",\n", + " f\"+val_ds_meta.{dataset_name}.manifest_path={dev_manifest_filepath}\",\n", + " f\"+val_ds_meta.{dataset_name}.audio_dir={audio_dir}\",\n", + " f\"+log_ds_meta.{dataset_name}.manifest_path={dev_manifest_filepath}\",\n", + " f\"+log_ds_meta.{dataset_name}.audio_dir={audio_dir}\"\n", + "]\n", + "\n", + "if MODEL_CHECKPOINT_PATH is not None:\n", + " args.append(f\"+init_from_nemo_model={MODEL_CHECKPOINT_PATH}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Bn8lQG0PxWGi" + }, + "outputs": [], + "source": [ + "# If an error occurs, log the entire stacktrace.\n", + "os.environ[\"HYDRA_FULL_ERROR\"] = \"1\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yUxFCNrE3Ywi" + }, + "outputs": [], + "source": [ + "# Do the model training. For some configurations this step might hang when using CPU.\n", + "run_script(audio_codec_training_script, args)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BBPIpS-lL6z9" + }, + "source": [ + "During training, the model will automatically save predictions for all audio files specified in the `log_ds_meta` manifest." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rSFOm1Sg46Lh" + }, + "outputs": [], + "source": [ + "codec_log_epoch_dir = codec_log_dir / \"epoch_10\" / dataset_name\n", + "!ls $codec_log_epoch_dir" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oCJs7oCLMIjD" + }, + "source": [ + "This makes it easy to listen to the audio to determine how well the model is performing. We can decide to stop training when either:\n", + "\n", + "* The predicted audio sounds almost identical to the original audio.\n", + "* The predicted audio stops improving in between epochs.\n", + "\n", + "**Note that when training from scratch, the dataset in this tutorial is too small to get good audio quality.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "G6k4ymzfJ5Y6" + }, + "outputs": [], + "source": [ + "audio_filepath_ground_truth = output_audio_dir / \"p228_009.flac\"\n", + "audio_filepath_reconstructed = codec_log_epoch_dir / \"p228_009_audio_out.wav\"\n", + "\n", + "print(\"Ground truth audio.\")\n", + "ipd.display(ipd.Audio(audio_filepath_ground_truth, rate=SAMPLE_RATE))\n", + "\n", + "print(\"Reconstructed audio.\")\n", + "ipd.display(ipd.Audio(audio_filepath_reconstructed, rate=SAMPLE_RATE))\n", + "\n", + "dequantized_filepath = codec_log_epoch_dir / \"p228_009_dequantized.png\"\n", + "ipd.Image(dequantized_filepath)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Related Information" + ], + "metadata": { + "id": "rynZYwg2VP5d" + } + }, + { + "cell_type": "markdown", + "source": [ + "To learn more about audio codec models in NeMo, look at our [documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/tts/models.html#codecs).\n", + "\n", + "For more information on how to download and run pre-trained audio codec models, visit [NGC](https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=codec)." + ], + "metadata": { + "id": "_LtyHHuLkNDv" + } + }, + { + "cell_type": "markdown", + "source": [ + "# References" + ], + "metadata": { + "id": "LeqV3VvJVOb-" + } + }, + { + "cell_type": "markdown", + "source": [ + "1. [EnCodec](https://arxiv.org/abs/2210.13438)\n", + "2. [Finite Scalar Quantization (FSQ)](https://arxiv.org/abs/2309.15505)\n", + "3. [HiFi-GAN](https://arxiv.org/abs/2010.05646)\n", + "4. [SEANet](https://arxiv.org/abs/2009.02095)" + ], + "metadata": { + "id": "Rvu4w2x_3RSY" + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 8005fb2f7ef4936e090f93f5f80b0e76bfa18e78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 2 May 2024 15:31:45 -0400 Subject: [PATCH 13/73] Improved `shard_id` parsing in `LazyNemoTarredIterator`, enables AIS dataloading (#9077) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * More permissive shard_id parsing, enables AIS dataloading Signed-off-by: Piotr Żelasko * Fix to shard id discovery Signed-off-by: Piotr Żelasko * More informative assertion errors Signed-off-by: Piotr Żelasko --------- Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/nemo_adapters.py | 59 +++++++++++++------ 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index 02b3e1f4edda..b8769b041b4f 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -14,7 +14,6 @@ import random import re -import secrets import tarfile from io import BytesIO from pathlib import Path @@ -147,9 +146,20 @@ class LazyNeMoTarredIterator: Args ``manifest_path`` and ``tar_paths`` can be either a path/string to a single file, or a string in NeMo format that indicates multiple paths (e.g. "[[data/bucket0/tarred_audio_paths.json],[data/bucket1/...]]"). + We discover shard ids from sharded tar and json files by parsing the input specifier/path and + searching for the following pattern: ``(manifest|audio)[^/]*_(\d+)[^/]*\.(json|tar)``. + It allows filenames such as ``manifest_0.json``, ``manifest_0_normalized.json``, ``manifest_normalized_0.json``, + ``manifest_0.jsonl.gz``, etc. (anologusly the same applies to tar files). + + We also support generalized input specifiers that imitate webdataset's pipes (also very similar to Kaldi's pipes). + These are arbitrary shell commands to be lazily executed which yield manifest or tar audio contents. + For example, ``tar_paths`` can be set to ``pipe:ais get ais://my-bucket/audio_{0..127}.tar -`` + to indicate that we want to read tarred audio data from shards on an AIStore bucket. + This can be used for other cloud storage APIs such as S3, GCS, etc. + The same mechanism applies to ``manifest_path``. The ``shard_seed`` argument is used to seed the RNG shuffling the shards. - By default it's ``trng`` which samples a seed number from OS-provided TRNG (see Python ``secrets`` module). + By default, it's ``trng`` which samples a seed number from OS-provided TRNG (see Python ``secrets`` module). Seed is resolved lazily so that every dataloading worker may sample a different one. Override with an integer value for deterministic behaviour and consult Lhotse documentation for details: https://lhotse.readthedocs.io/en/latest/datasets.html#handling-random-seeds @@ -172,30 +182,36 @@ def __init__( text_field: str = "text", lang_field: str = "lang", ) -> None: - def strip_pipe(p): - if isinstance(p, str): - if p.startswith("pipe:"): - p = p[5:] - return Path(p) - return p - self.shard_id_to_manifest: dict[int, Iterable[dict]] self.paths = expand_sharded_filepaths(manifest_path) if len(self.paths) == 1: self.source = LazyJsonlIterator(self.paths[0]) self.shard_id_to_manifest = groupby("shard_id", self.source) else: - pattern = re.compile(r".+_(\d+)\.jsonl?(?:.gz)?") + json_pattern = re.compile(r"manifest[^/]*_(\d+)[^/]*\.json") shard_ids = [] for p in self.paths: - m = pattern.match(p) - assert m is not None, f"Cannot determine shard_id from manifest path: {p}" + m = json_pattern.search(p) + assert m is not None, ( + f"Cannot determine shard_id from manifest input specified: " + f"we searched with regex '{json_pattern.pattern}' in input '{p}'" + ) shard_ids.append(int(m.group(1))) self.shard_id_to_manifest = {sid: LazyJsonlIterator(p) for sid, p in zip(shard_ids, self.paths)} self.source = LazyIteratorChain(*self.shard_id_to_manifest.values()) - tar_paths = expand_sharded_filepaths(tar_paths) - self.shard_id_to_tar_path: dict[int, str] = {int(strip_pipe(p).stem.split("_")[1]): p for p in tar_paths} + self.tar_paths = expand_sharded_filepaths(tar_paths) + tar_pattern = re.compile(r"audio[^/]*_(\d+)[^/]*\.tar") + shard_ids = [] + for p in self.tar_paths: + m = tar_pattern.search(p) + assert m is not None, ( + f"Cannot determine shard_id from tar input specifier: " + f"we searched with regex '{tar_pattern.pattern}' in input '{p}'" + ) + shard_ids.append(int(m.group(1))) + self.shard_id_to_tar_path = dict(zip(shard_ids, self.tar_paths)) + self.shuffle_shards = shuffle_shards self.shard_seed = shard_seed self.text_field = text_field @@ -225,8 +241,11 @@ def _validate(self) -> None: shard_ids_tars = set(self.shard_id_to_tar_path) shard_ids_manifest = set(self.shard_id_to_manifest) assert shard_ids_tars == shard_ids_manifest, ( - f"Mismatch between shard IDs discovered from tar files ({len(shard_ids_tars)=}) and " - f"JSON manifest ({len(shard_ids_manifest)=}): {shard_ids_tars - shard_ids_manifest=}" + f"Mismatch between shard IDs. Details:\n" + f"* JSON manifest(s) {self.paths}\n" + f"* Tar files: {self.tar_paths}\n" + f"* JSON manifest(s) indicate(s) IDs: {sorted(shard_ids_manifest)}\n" + f"* Tar path(s) indicate(s) IDs: {sorted(shard_ids_tars)}\n" ) @property @@ -245,9 +264,11 @@ def __iter__(self) -> Generator[Cut, None, None]: tar_path = self.shard_id_to_tar_path[sid] with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar: for data, tar_info in zip(shard_manifest, tar): - assert ( - data["audio_filepath"] == tar_info.name - ), f"Mismatched JSON manifest and tar file. {data['audio_filepath']=} != {tar_info.name=}" + manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0] + assert data["audio_filepath"] == tar_info.name, ( + f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). " + f"Conflicting audio file names are JSON='{data['audio_filepath']}' and TAR='{tar_info.name}'" + ) raw_audio = tar.extractfile(tar_info).read() # Note: Lhotse has a Recording.from_bytes() utility that we won't use here because # the profiling indicated significant overhead in torchaudio ffmpeg integration From e16d06999bd088eb2fa1b1787628fca9929613bf Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Thu, 2 May 2024 23:45:53 +0200 Subject: [PATCH 14/73] [NeMo-UX] Add mistral-7b model (#9066) * Adding MegatronParallel * Move over _strategy_liMegatronCheckpointIO * Adding GPTModel & MockDataModule * Adding mixed-precision to NeMo * Fix import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Adding MegatronParallel * Move over _strategy_liMegatronCheckpointIO * Adding GPTModel & MockDataModule * Add nemo.io to MegatronStrategy * Move to cloudpickle * Adding Mistral7B model * Fix small bug inside state-transform * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert unintended changes Signed-off-by: Chen Cui * clean up code and reinstate mix precision tests Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean up Signed-off-by: Chen Cui * use cpu for unit test Signed-off-by: Chen Cui * clean up Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix test Signed-off-by: Chen Cui * mistral requires hf login so use a toy model for now Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert accidental change Signed-off-by: Chen Cui --------- Signed-off-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Chen Cui --- nemo/io/__init__.py | 17 +- nemo/io/api.py | 171 +++++++++- nemo/io/connector.py | 179 +++++++++++ nemo/io/mixin.py | 184 ++++++++++- nemo/io/state.py | 403 ++++++++++++++++++++++++ nemo/llm/__init__.py | 12 +- nemo/llm/gpt/model/__init__.py | 11 +- nemo/llm/gpt/model/base.py | 7 +- nemo/llm/gpt/model/mistral_7b.py | 263 ++++++++++++++++ requirements/requirements_lightning.txt | 1 + tests/io/test_state.py | 233 ++++++++++++++ 11 files changed, 1472 insertions(+), 9 deletions(-) create mode 100644 nemo/io/connector.py create mode 100644 nemo/io/state.py create mode 100644 nemo/llm/gpt/model/mistral_7b.py create mode 100644 tests/io/test_state.py diff --git a/nemo/io/__init__.py b/nemo/io/__init__.py index 5b1d48768848..1b541ff7ba34 100644 --- a/nemo/io/__init__.py +++ b/nemo/io/__init__.py @@ -1,14 +1,25 @@ -from nemo.io.api import load, load_ckpt +from nemo.io.api import export_ckpt, import_ckpt, load, load_ckpt, model_exporter, model_importer from nemo.io.capture import reinit -from nemo.io.mixin import IOMixin +from nemo.io.connector import Connector, ModelConnector +from nemo.io.mixin import ConnectorMixin, IOMixin from nemo.io.pl import TrainerCheckpoint, is_distributed_ckpt - +from nemo.io.state import TransformCTX, apply_transforms, state_transform __all__ = [ + "apply_transforms", + "Connector", + "ConnectorMixin", "IOMixin", + "import_ckpt", "is_distributed_ckpt", + "export_ckpt", "load", "load_ckpt", + "ModelConnector", + "model_importer", + "model_exporter", 'reinit', + "state_transform", "TrainerCheckpoint", + "TransformCTX", ] diff --git a/nemo/io/api.py b/nemo/io/api.py index f7de36cb9545..c8fe3c04a811 100644 --- a/nemo/io/api.py +++ b/nemo/io/api.py @@ -1,9 +1,11 @@ import pickle from pathlib import Path -from typing import Any, Type, TypeVar +from typing import Any, Callable, Optional, Type, TypeVar import fiddle as fdl +import pytorch_lightning as pl +from nemo.io.mixin import ConnectorMixin, ConnT, ModelConnector from nemo.io.pl import TrainerCheckpoint CkptType = TypeVar("CkptType") @@ -60,3 +62,170 @@ def load_ckpt(path: Path) -> TrainerCheckpoint: checkpoint: TrainerCheckpoint = load_ckpt("/path/to/checkpoint") """ return load(path, output_type=TrainerCheckpoint) + + +def model_importer( + target: Type[ConnectorMixin], ext: str, default_path: Optional[str] = None +) -> Callable[[Type[ConnT]], Type[ConnT]]: + """ + Registers an importer for a model with a specified file extension and an optional default path. + + Args: + target (Type[ConnectorMixin]): The model class to which the importer will be attached. + ext (str): The file extension associated with the model files to be imported. + default_path (Optional[str]): The default path where the model files are located, if any. + + Returns + ------- + Callable[[Type[ConnT]], Type[ConnT]]: A decorator function that registers the importer + to the model class. + + Example: + @model_importer(MyModel, "hf", default_path="path/to/default") + class MyModelHfImporter(io.ModelConnector): + ... + """ + return target.register_importer(ext, default_path=default_path) + + +def model_exporter( + target: Type[ConnectorMixin], ext: str, default_path: Optional[str] = None +) -> Callable[[Type[ConnT]], Type[ConnT]]: + """ + Registers an exporter for a model with a specified file extension and an optional default path. + + Args: + target (Type[ConnectorMixin]): The model class to which the exporter will be attached. + ext (str): The file extension associated with the model files to be exported. + default_path (Optional[str]): The default path where the model files will be saved, if any. + + Returns + ------- + Callable[[Type[ConnT]], Type[ConnT]]: A decorator function that registers the exporter + to the model class. + + Example: + @model_exporter(MyModel, "hf", default_path="path/to/default") + class MyModelHFExporter(io.ModelConnector): + ... + """ + return target.register_exporter(ext, default_path=default_path) + + +def import_ckpt( + model: pl.LightningModule, source: str, output_path: Optional[Path] = None, overwrite: bool = False +) -> Path: + """ + Imports a checkpoint into a model using the model's associated importer, typically for + the purpose of fine-tuning a community model trained in an external framework, such as + Hugging Face. This function leverages the ConnectorMixin interface to integrate external + checkpoint data seamlessly into the specified model instance. + + The importer component of the model reads the checkpoint data from the specified source + and transforms it into the right format. This is particularly useful for adapting + models that have been pre-trained in different environments or frameworks to be fine-tuned + or further developed within the current system. The function allows for specifying an output + path for the imported checkpoint; if not provided, the importer's default path will be used. + The 'overwrite' parameter enables the replacement of existing data at the output path, which + is useful when updating models with new data and discarding old checkpoint files. + + For instance, using `import_ckpt(Mistral7BModel(), "hf")` initiates the import process + by searching for a registered model importer tagged with "hf". In NeMo, `HFMistral7BImporter` + is registered under this tag via: + `@io.model_importer(Mistral7BModel, "hf", default_path="mistralai/Mistral-7B-v0.1")`. + This links `Mistral7BModel` to `HFMistral7BImporter`, designed for HuggingFace checkpoints. + The importer then processes and integrates these checkpoints into `Mistral7BModel` for further + fine-tuning. + + Args: + model (pl.LightningModule): The model into which the checkpoint will be imported. + This model must implement the ConnectorMixin, which includes the necessary + importer method for checkpoint integration. + source (str): The source from which the checkpoint will be imported. This can be + a file path, URL, or any other string identifier that the model's importer + can recognize. + output_path (Optional[Path]): The path where the imported checkpoint will be stored. + If not specified, the importer's default path is used. + overwrite (bool): If set to True, existing files at the output path will be overwritten. + This is useful for model updates where retaining old checkpoint files is not required. + + Returns + ------- + Path: The path where the checkpoint has been saved after import. This path is determined + by the importer, based on the provided output_path and its internal logic. + + Raises + ------ + ValueError: If the model does not implement ConnectorMixin, indicating a lack of + necessary importer functionality. + + Example: + model = Mistral7BModel() + imported_path = import_ckpt(model, "hf") + """ + if not isinstance(model, ConnectorMixin): + raise ValueError("Model must be an instance of ConnectorMixin") + + importer: ModelConnector = model.importer(source) + return importer(overwrite=overwrite, output_path=output_path) + + +def load_connector_from_trainer_ckpt(path: Path, target: str) -> ModelConnector: + model: pl.LightningModule = load_ckpt(path).model + + if not isinstance(model, ConnectorMixin): + raise ValueError("Model must be an instance of ConnectorMixin") + + return model.exporter(target, path) + + +def export_ckpt( + path: Path, + target: str, + output_path: Optional[Path] = None, + overwrite: bool = False, + load_connector: Callable[[Path, str], ModelConnector] = load_connector_from_trainer_ckpt, +) -> Path: + """ + Exports a checkpoint from a model using the model's associated exporter, typically for + the purpose of sharing a model that has been fine-tuned or customized within NeMo. + This function leverages the ConnectorMixin interface to seamlessly integrate + the model's state into an external checkpoint format. + + The exporter component of the model reads the model's state from the specified path and + exports it into the format specified by the 'target' identifier. This is particularly + useful for adapting models that have been developed or fine-tuned within the current system + to be compatible with other environments or frameworks. The function allows for specifying + an output path for the exported checkpoint; if not provided, the exporter's default path + will be used. The 'overwrite' parameter enables the replacement of existing data at the + output path, which is useful when updating models with new data and discarding old checkpoint + files. + + Args: + path (Path): The path to the model's checkpoint file from which data will be exported. + target (str): The identifier for the exporter that defines the format of the export. + output_path (Optional[Path]): The path where the exported checkpoint will be saved. + If not specified, the exporter's default path is used. + overwrite (bool): If set to True, existing files at the output path will be overwritten. + This is useful for model updates where retaining old checkpoint files is not required. + load_connector (Callable[[Path, str], ModelConnector]): A function to load the appropriate + exporter based on the model and target format. Defaults to `load_connector_from_trainer_ckpt`. + + Returns + ------- + Path: The path where the checkpoint has been saved after export. This path is determined + by the exporter, based on the provided output_path and its internal logic. + + Raises + ------ + ValueError: If the model does not implement ConnectorMixin, indicating a lack of + necessary exporter functionality. + + Example: + nemo_ckpt_path = Path("/path/to/model.ckpt") + export_path = export_ckpt(nemo_ckpt_path, "hf") + """ + exporter: ModelConnector = load_connector(path, target) + _output_path = output_path or Path(path) / target + + return exporter(overwrite=overwrite, output_path=_output_path) diff --git a/nemo/io/connector.py b/nemo/io/connector.py new file mode 100644 index 000000000000..bf5f88f95992 --- /dev/null +++ b/nemo/io/connector.py @@ -0,0 +1,179 @@ +import os +import shutil +from pathlib import Path, PosixPath, WindowsPath +from typing import Generic, Optional, Tuple, TypeVar + +import pytorch_lightning as pl + +# Dynamically inherit from the correct Path subclass based on the operating system. +if os.name == 'nt': + BasePath = WindowsPath +else: + BasePath = PosixPath + + +SourceT = TypeVar("SourceT") +TargetT = TypeVar("TargetT") + + +class Connector(BasePath, Generic[SourceT, TargetT]): + """ + A generic connector class that provides a framework for transforming a source type (SourceT) + to a target type (TargetT) while handling file paths based on the operating system. + + Attributes + ---------- + default_path (Optional[Path]): A default path used when no path is explicitly provided. + + Methods + ------- + init() -> TargetT: + Should be implemented to initialize the target type from the source type. + + apply(output_path: Path) -> Path: + Should be implemented to apply the transformation and save the result at the output path. + + __new__(cls, *args, **kwargs) -> 'Connector': + Creates a new instance of the connector, using default_path if no path is provided. + + __call__(output_path: Optional[Path] = None, overwrite: bool = False) -> Path: + Processes the transformation and handles file operations like overwriting. + + local_path(base_path: Optional[Path] = None) -> Path: + Computes the local path for storage based on a base path or a default cache home. + + is_in_cache(base_path: Optional[Path] = None) -> bool: + Checks if the transformed data is already cached at the specified base path. + """ + + default_path = None + + def init(self) -> TargetT: + raise NotImplementedError() + + def apply(self, output_path: Path) -> Path: + raise NotImplementedError() + + def __new__(cls, *args, **kwargs): + if cls.default_path is not None and not args and 'path' not in kwargs: + # If default_path is set and no arguments are provided, use default_path as the argument + return super().__new__(cls, cls.default_path) + + return super().__new__(cls, *args, **kwargs) + + def __call__(self, output_path: Optional[Path] = None, overwrite: bool = False) -> Path: + _output_path = output_path or self.local_path() + + if overwrite and _output_path.exists(): + shutil.rmtree(_output_path) + + if not _output_path.exists(): + to_return = self.apply(_output_path) + _output_path = to_return or _output_path + + return _output_path + + def local_path(self, base_path: Optional[Path] = None) -> Path: + if base_path: + _base = base_path + else: + from nemo.lightning.base import NEMO_CACHE_HOME + + _base = Path(NEMO_CACHE_HOME) + + return _base / str(self).replace("://", "/") + + def is_in_cache(self, base_path: Optional[Path] = None) -> bool: + return self.local_path(base_path=base_path).exists() + + +class ModelConnector(Connector, Generic[SourceT, TargetT]): + """ + A specialized connector that extends the generic Connector to handle model-specific operations + such as setup, save, and load using the Lightning framework. + + Methods + ------- + nemo_setup(model: pl.LightningModule, trainer: Optional[pl.Trainer] = None) -> pl.Trainer: + Sets up the model and trainer using a specified strategy, preparing it for training or inference. + + nemo_save(output_path: Path, trainer: pl.Trainer): + Saves the model's state to the specified path using the trainer's current strategy. + + nemo_load(path: Path, trainer: Optional[pl.Trainer] = None, cpu: bool = True) -> Tuple[Any, pl.Trainer]: + Loads a model from the specified path, optionally using a CPU-focused strategy, and returns the model and trainer. + """ + + def nemo_setup(self, model: pl.LightningModule, trainer: Optional[pl.Trainer] = None) -> pl.Trainer: + """ + Sets up the model and trainer using a specified strategy, preparing it for training or inference. + + Args: + model (pl.LightningModule): The model to be set up. + trainer (Optional[pl.Trainer]): The trainer to be used, if not provided a new one will be created. + + Returns + ------- + pl.Trainer: The trainer configured with the model and strategy. + """ + from nemo.lightning import MegatronStrategy, Trainer + + _trainer = trainer or Trainer(devices=1, accelerator="cpu", strategy=MegatronStrategy()) + + _trainer.strategy.connect(model) + _trainer.strategy.setup_environment() + + if not model.state_dict(): + _trainer.strategy.lazy_init = True + with _trainer.init_module(): + model.configure_model() + + return _trainer + + def nemo_save(self, output_path: Path, trainer: pl.Trainer) -> None: + """ + Saves the model's state to the specified path using the trainer's current strategy. + + Args: + output_path (Path): The path where the model checkpoint will be saved. + trainer (pl.Trainer): The trainer with the strategy to save the model. + """ + trainer.strategy.setup(trainer) + trainer.save_checkpoint(output_path) + + def nemo_load( + self, path: Path, trainer: Optional[pl.Trainer] = None, cpu: bool = True + ) -> Tuple[pl.LightningModule, pl.Trainer]: + """ + Loads a model from the specified path. + + Args: + path (Path): The path from which the model will be loaded. + trainer (Optional[pl.Trainer]): The trainer to be used, if not provided a new one will be created. + cpu (bool): If True, the model will be loaded with a CPU-focused strategy. + + Returns + ------- + Tuple[pl.LightningModule, pl.Trainer]: The loaded model and the trainer configured with the model. + """ + from nemo.io.api import load_ckpt + from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib + + model = load_ckpt(path).model + _trainer = trainer or Trainer(devices=1, accelerator="cpu" if cpu else "gpu", strategy=MegatronStrategy()) + + _trainer.strategy.connect(model) + _trainer.strategy.setup_environment() + # TODO: Fix cpu initialization + if not model.state_dict(): + if cpu: + # TODO: Make this more generic + with _strategy_lib.megatron_cpu_init_context(model.config): + model.configure_model() + else: + model.configure_model() + + _trainer.strategy.setup(_trainer) + _trainer.strategy.load_checkpoint(path) + + return model, _trainer diff --git a/nemo/io/mixin.py b/nemo/io/mixin.py index d09c456f7957..bba6677b452b 100644 --- a/nemo/io/mixin.py +++ b/nemo/io/mixin.py @@ -2,13 +2,16 @@ import inspect from dataclasses import is_dataclass from pathlib import Path -from typing import Any, Dict +from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union import fiddle as fdl from cloudpickle import dump from typing_extensions import Self from nemo.io.capture import IOProtocol +from nemo.io.connector import ModelConnector + +ConnT = TypeVar('ConnT', bound=ModelConnector) class IOMixin: @@ -137,3 +140,182 @@ def io_dump(self, output: Path): config_path = Path(output) / "io.pkl" with open(config_path, "wb") as f: dump(self.__io__, f) + + +class ConnectorMixin: + """ + A mixin class that provides methods to register and retrieve model connectors for importing + and exporting models. This class supports dynamic registration of connectors based on file + extensions, which facilitates the customization and extension of model serialization and + deserialization processes. + + Attributes + ---------- + _IMPORTERS (Dict[str, Type[ModelConnector]]): A dictionary mapping file extensions to + model connector classes that handle the import process. + _EXPORTERS (Dict[str, Type[ModelConnector]]): A dictionary mapping file extensions to + model connector classes that handle the export process. + """ + + _IMPORTERS: Dict[str, Type[ModelConnector]] = {} + _EXPORTERS: Dict[str, Type[ModelConnector]] = {} + + @classmethod + def import_from(cls, path: str) -> Self: + """ + Creates an instance of a model by using the appropriate importer based on the file + extension of the provided path. + + Args: + path (str): The path to the model file to be imported. + + Example: + from nemo import llm + model = llm.Mistral7BModel.import_from("hf") + + Returns + ------- + Self: An instance of the model initialized from the imported data. + """ + output = cls._get_connector(path).init() + output.ckpt_path = output.import_ckpt_path(path) + + return output + + @classmethod + def register_importer(cls, ext: str, default_path: Optional[str] = None) -> Callable[[Type[ConnT]], Type[ConnT]]: + """ + A class method decorator to register a model connector as an importer for a specific file + extension. + + Args: + ext (str): The file extension to associate with the model connector. + default_path (Optional[str]): The default path to use if no path is specified during import. + + Returns + ------- + Callable[[Type[ConnT]], Type[ConnT]]: The decorator that registers the model connector. + """ + + def decorator(connector: Type[ConnT]) -> Type[ConnT]: + cls._IMPORTERS[ext] = connector + if default_path: + connector.default_path = default_path + return connector + + return decorator + + @classmethod + def register_exporter(cls, ext: str, default_path: Optional[str] = None) -> Callable[[Type[ConnT]], Type[ConnT]]: + """ + A class method decorator to register a model connector as an exporter for a specific file + extension. + + Args: + ext (str): The file extension to associate with the model connector. + default_path (Optional[str]): The default path to use if no path is specified during export. + + Returns + ------- + Callable[[Type[ConnT]], Type[ConnT]]: The decorator that registers the model connector. + """ + + def decorator(connector: Type[ConnT]) -> Type[ConnT]: + cls._EXPORTERS[ext] = connector + if default_path: + connector.default_path = default_path + return connector + + return decorator + + @classmethod + def importer(cls, path: str) -> ModelConnector: + """ + Retrieves the appropriate model connector for importing based on the extension of the + provided path. + + Args: + path (str): The path to the model file to be imported. + + Returns + ------- + ModelConnector: The model connector instance capable of handling the import. + """ + return cls._get_connector(path, importer=True) + + @classmethod + def exporter(cls, ext: str, path: Union[str, Path]) -> ModelConnector: + """ + Retrieves the appropriate model connector for exporting based on the extension. + + Args: + ext (str): The file extension associated with the model connector. + path (Union[str, Path]): The path where the model will be exported. + + Returns + ------- + ModelConnector: The model connector instance capable of handling the export. + """ + return cls._get_connector(ext, path, importer=False) + + def import_ckpt(self, path: str, overwrite: bool = False, base_path: Optional[Path] = None) -> Path: + """ + Imports a checkpoint from a specified path, potentially overwriting existing files. + + Args: + path (str): The path to the checkpoint file to be imported. + overwrite (bool): Flag to determine if existing files should be overwritten (default is False). + base_path (Optional[Path]): The base path where the checkpoint file is located; used to resolve + relative paths. + + Returns + ------- + Path: The path to the imported checkpoint. + + Raises + ------ + FileNotFoundError: If the checkpoint file does not exist at the specified path. + """ + connector = self._get_connector(path) + ckpt_path: Path = connector.local_path(base_path=base_path) + ckpt_path = connector(ckpt_path, overwrite=overwrite) + + return ckpt_path + + @classmethod + def _get_connector(cls, ext, path=None, importer=True) -> ModelConnector: + """ + Retrieves the appropriate model connector based on the file extension and path, + distinguishing between importers and exporters. + + Args: + ext (str): The file extension or a URI that may include a protocol specifier. + path (Optional[Union[str, Path]]): The path where the model file is located or will be saved. + importer (bool): Flag to determine if the connector is for importing (True) or exporting (False). + + Returns + ------- + ModelConnector: The model connector instance capable of handling the import or export. + + Raises + ------ + ValueError: If no connector is found for the specified extension or if no default path is provided + when required. + """ + _path = None + if "://" in ext: + ext, _path = ext.split("://") + else: + _path = path + + connector = cls._IMPORTERS.get(ext) if importer else cls._EXPORTERS.get(ext) + if not connector: + raise ValueError(f"No connector found for extension '{ext}'") + + if not _path: + if not connector.default_path: + raise ValueError(f"No default path specified for extension '{ext}'. ", "Please provide a path") + + return connector() + + return connector(_path) diff --git a/nemo/io/state.py b/nemo/io/state.py new file mode 100644 index 000000000000..d978cd0ade8e --- /dev/null +++ b/nemo/io/state.py @@ -0,0 +1,403 @@ +import inspect +import re +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload + +import numpy as np +from torch import nn + +SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module) +TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module) +F = TypeVar("F", bound=Callable[..., Any]) + + +@dataclass +class TransformCTX: + source: nn.Module + source_state: dict + target: nn.Module + target_state: dict + + +def apply_transforms( + source: nn.Module, + target: TargetModuleT, + mapping: Dict[str, str], + transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None, +) -> TargetModuleT: + """ + Applies a series of transformations to adapt the state dictionary of a source module to + match the structure of a target module's state dictionary. + + This function renames keys according to a provided mapping and modifies values using a list + of transformation functions. Each transformation function typically is decorated + with `io.state_transform`. + + Args: + source (nn.Module): The source module from which parameters and buffers are taken. + target (TargetModuleT): The target module to which parameters and buffers are adapted. + mapping (Dict[str, str]): Key-value pairs where each key from the source state dictionary + is mapped to a corresponding key in the target state dictionary. + transforms (Optional[List[Callable[[TransformCTX], TransformCTX]]]): A list of functions + that modify the `TransformCTX` object. If None, no transformations beyond key renaming + are applied. Defaults to None. + + Returns + ------- + TargetModuleT: The modified target module with its state dictionary adjusted according to + the specified mappings and transformations. + + Raises + ------ + ValueError: If there's a mismatch in shape between corresponding source and target parameters + or buffers. + RuntimeError: If the target state dictionary contains keys that are not present in the source + state dictionary after all transformations. + + Examples + -------- + >>> source_module = nn.Linear(10, 5) + >>> target_module = nn.Linear(10, 5) + >>> mapping = {'weight': 'weights', 'bias': 'biases'} + @io.state_transform( + source_key="weight", + target_key="weights" + ) + def scale_weights(ctx): + ctx.target_state['weights'] = ctx.source_state['weight'] * 2 + return ctx + >>> transformed_target = apply_transforms( + ... source_module, target_module, mapping, [scale_weights] + ... ) + >>> print(transformed_target.state_dict()['weights']) + + See Also + -------- + - `TransformCTX`: For more details on the context object used in transformations. + - `StateDictTransform`: For creating complex transformations. + + Note: + This function is particularly useful when adapting models from different frameworks or + when consolidating models with different architectural changes. + """ + from megatron.core.transformer.module import MegatronModule + + # TODO: How can we improve this? + _source = source + if hasattr(source, "module") and isinstance(source.module, MegatronModule): + _source = source.module + _target = target + if hasattr(target, "module") and isinstance(target.module, MegatronModule): + _target = target.module + + target_state = _target.state_dict() + ctx = TransformCTX(source=_source, source_state=_source.state_dict(), target=_target, target_state=target_state,) + + for key, val in mapping.items(): + ctx = StateDictTransform(key, val)(ctx) + + if transforms: + for transform in transforms: + ctx = transform(ctx) + + _params: Dict[str, nn.Parameter] = {} + for name, param in _target.named_parameters(): + if name in target_state: + target_param = target_state[name] + if param.data.shape != target_param.shape: + raise ValueError(f"Shape mismatch for parameter {name}: {param.shape} vs {target_param.shape}") + + _params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad) + target_state.pop(name) + else: + print(f"Unexpected key: {name} not in checkpoint but in model.") + + for key, val in _params.items(): + _module, _key = _target, key + if "." in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + + _module.register_parameter(_key, val) + + _buffers = {} + for name, buffer in _target.named_buffers(): + if name in target_state: + if buffer.shape != target_state[name].shape: + raise ValueError(f"Shape mismatch for buffer {name}: {buffer.shape} vs {target_state[name].shape}") + + _buffers[name] = nn.Parameter(target_state[name], requires_grad=False) + target_state.pop(name) + + for key, val in _buffers.items(): + _module, _key = _target, key + if "." in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + + _module.register_buffer(_key, val) + + keys = [name for name in list(target_state.keys()) if not name.endswith("_extra_state")] + if len(keys) != 0: + raise RuntimeError(f"Additional keys: {target_state.keys()} in checkpoint but not in model.") + + # TODO: Is this correct? + # for key in target.state_dict(): + # if key.endswith("_extra_state"): + # del target.state_dict()[key] + + """finally: + cls._set_model_restore_state(is_being_restored=False)""" + + if hasattr(target, "module") and isinstance(target.module, MegatronModule): + target.module = _target + + return target + + return _target + + +def _default_transform(inp): + return inp.float() + + +class StateDictTransform(Generic[F]): + """ + A transformation class for state dictionaries, allowing for flexible key matching and + transformation of values between source and target state dictionaries. + + Attributes + ---------- + source_key: A string, tuple of strings, or a dictionary specifying the keys in the source + state dictionary to match. Wildcards (*) are supported. + target_key: A string or tuple of strings specifying the keys in the target state dictionary + to match. Wildcards (*) are supported. + transform: A callable that performs the transformation on matched keys' values. + + Examples + -------- + >>> def example_transform(ctx, *args): + ... return sum(args) + >>> transform = StateDictTransform( + ... source_key="model.layers.*.self_attn.*_proj.weight", + ... target_key="decoder.layers.*.self_attention.linear_qkv.weight", + ... transform=example_transform + ... ) + """ + + def __init__( + self, + source_key: Union[str, Tuple[str, ...], Dict[str, str]], + target_key: Union[str, Tuple[str, ...]], + transform: F = _default_transform, + ): + self.source_key = source_key + self.target_key = target_key + self.transform = transform + + def __call__(self, ctx: TransformCTX) -> TransformCTX: + source_key = self.source_key + target_key = self.target_key + source_dict, target_dict = ctx.source_state, ctx.target_state + + fn_params = dict(inspect.signature(self.transform).parameters) + fn_params.pop("ctx", None) + + if isinstance(source_key, (dict, tuple)): + if isinstance(source_key, tuple): + source_key_dict = {param: source_key[i] for i, param in enumerate(fn_params)} + else: + source_key_dict = source_key + source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()} + target_matches = _match_keys(list(target_dict.keys()), target_key) + + for target_index, target_match in np.ndenumerate(target_matches): + kwargs = {} + for param in fn_params: + if param in source_matches_dict: + source_match = source_matches_dict[param][target_index[:-1]] + kwargs[param] = source_dict[source_match[target_index]] + + target_dict[target_match] = self.call_transform(ctx, **kwargs) + else: + source_keys = list(source_dict.keys()) + target_keys = list(target_dict.keys()) + + source_matches = _match_keys(source_keys, source_key) + if source_matches.size == 1 and source_matches == np.array(None): + raise ValueError(f"No matches found for source key: {source_key}") + + if isinstance(target_key, str): + target_matches = _match_keys(target_keys, target_key) + if target_matches.size < 1: + raise ValueError(f"No matches found for target key: {target_key}") + else: + if isinstance(target_key, dict): + raise ValueError("Target key must be a string or a tuple of strings.") + + _matches = np.vstack([_match_keys(target_keys, key) for key in target_key]) + target_matches = np.transpose(_matches) + + # Determine if we are dealing with multiple source matches or multiple target matches + multiple_sources = source_matches.ndim >= target_matches.ndim + accepts_var_args = any( + param.kind == param.VAR_POSITIONAL for param in inspect.signature(self.transform).parameters.values() + ) + + if multiple_sources: + for target_index, target_match in np.ndenumerate(target_matches): + source_match = source_matches[target_index] + + if accepts_var_args: + source_values = [source_dict[k] for k in source_match] + target_dict[target_match] = self.call_transform(ctx, *source_values) + else: + _source_match_list = [source_match] if isinstance(source_match, str) else list(source_match) + if len(fn_params) != len(_source_match_list): + raise ValueError( + f"Mismatch between source and target keys: {source_match} vs {target_match}" + ) + + kwargs = {param: source_dict[k] for param, k in zip(fn_params, _source_match_list)} + target_dict[target_match] = self.call_transform(ctx, **kwargs) + else: + if source_matches.ndim == 0: + source_matches_list = [source_matches.item()] + source_matches = np.array(source_matches_list, dtype=object) + else: + source_matches_list = list(source_matches) + + if source_matches.shape[0] != target_matches.shape[0]: + if target_matches.shape[0] == 1 and source_matches.shape[0] == target_matches.shape[1]: + source_matches_list = [source_matches_list] + else: + raise ValueError( + "Mismatch between source and target keys: {source_matches} vs {target_matches}" + ) + + for source_index, source_match in enumerate(source_matches_list): + target_match = target_matches[source_index] + source_values = ( + [source_dict[source_match]] + if np.isscalar(source_match) + else [source_dict[k] for k in source_match] + ) + if accepts_var_args: + outputs = self.call_transform(ctx, *source_values) + else: + kwargs = {param: val for param, val in zip(fn_params, source_values)} + outputs = self.call_transform(ctx, **kwargs) + + if isinstance(target_match, str): + target_dict[target_match] = outputs + else: + for i, t in enumerate(outputs): + target_dict[target_match[i]] = t + + return ctx + + def call_transform(self, ctx: TransformCTX, *args, **kwargs): + func_params = inspect.signature(self.transform).parameters + expected_num_args = len([p for p in func_params if p not in ['self', 'ctx']]) + provided_num_args = len(args) + len(kwargs) + accepts_var_args = any(param.kind == param.VAR_POSITIONAL for param in func_params.values()) + + if not accepts_var_args and provided_num_args != expected_num_args: + raise ValueError( + f"Expected {expected_num_args} arguments for the transformation function, but got {provided_num_args}." + ) + + if 'ctx' in func_params: + return self.transform(ctx, *args, **kwargs) + + return self.transform(*args, **kwargs) + + +def _match_keys(keys: List[str], pattern: str) -> np.ndarray: + regex_pattern = re.compile("^" + pattern.replace("*", "(.*)") + "$") + wildcard_matches = [[] for _ in range(pattern.count("*"))] + + for key in keys: + match = regex_pattern.match(key) + if match: + for i, group in enumerate(match.groups()): + if group not in wildcard_matches[i]: + wildcard_matches[i].append(group) + + # Sort the wildcard matches to maintain consistent ordering + for i in range(len(wildcard_matches)): + wildcard_matches[i].sort(key=lambda x: int(x) if x.isdigit() else x) + + # Determine the shape of the output array based on the unique matches for each wildcard + shape = [len(matches) for matches in wildcard_matches] + + # Initialize an empty array with the determined shape + output_array = np.empty(shape, dtype=object) + + # Populate the array with the keys, now that we have the correct shape and ordering + for key in keys: + match = regex_pattern.match(key) + if match: + # Convert match groups to indices based on their position in wildcard_matches + indices = [wildcard_matches[i].index(group) for i, group in enumerate(match.groups())] + output_array[tuple(indices)] = key # Place the key in the array based on the indices + + return output_array + + +@overload +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], target_key: Union[str, Tuple[str, ...]], +) -> Callable[[F], StateDictTransform[F]]: + ... + + +@overload +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], target_key: Union[str, Tuple[str, ...]], fn: F +) -> StateDictTransform[F]: + ... + + +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], + target_key: Union[str, Tuple[str, ...]], + fn: Optional[F] = None, +): + """ + A decorator for creating StateDictTransform instances with specified source and target keys, + and a transformation function. This allows for concise definition of state dictionary + transformations. + + Args: + source_key: A string, tuple of strings, or a dictionary specifying the keys in the source + state dictionary to match. Wildcards (*) are supported. + target_key: A string or tuple of strings specifying the keys in the target state dictionary + to match. Wildcards (*) are supported. + fn: An optional callable that performs the transformation on matched keys' values. If not + provided, the decorator can be used to wrap a function definition. + + Returns + ------- + A StateDictTransform instance if `fn` is provided, otherwise returns a decorator that + takes a function and returns a StateDictTransform instance. + + Examples + -------- + >>> @state_transform( + ... source_key="model.layers.*.self_attn.*_proj.weight", + ... target_key="decoder.layers.*.self_attention.linear_qkv.weight" + ... ) + ... def sum_transform(ctx, *args): + ... return sum(args) + """ + + def wrapper(fn) -> StateDictTransform: + return StateDictTransform(source_key, target_key, fn) + + if fn is None: + return wrapper + + return wrapper(fn) diff --git a/nemo/llm/__init__.py b/nemo/llm/__init__.py index 2dd39b3f170e..a05c96f60944 100644 --- a/nemo/llm/__init__.py +++ b/nemo/llm/__init__.py @@ -1,5 +1,13 @@ from nemo.llm.gpt.data import MockDataModule -from nemo.llm.gpt.model import GPTConfig, GPTModel, MaskedTokenLossReduction, gpt_data_step, gpt_forward_step +from nemo.llm.gpt.model import ( + GPTConfig, + GPTModel, + MaskedTokenLossReduction, + Mistral7BConfig, + Mistral7BModel, + gpt_data_step, + gpt_forward_step, +) __all__ = [ "MockDataModule", @@ -8,4 +16,6 @@ "gpt_data_step", "gpt_forward_step", "MaskedTokenLossReduction", + "Mistral7BConfig", + "Mistral7BModel", ] diff --git a/nemo/llm/gpt/model/__init__.py b/nemo/llm/gpt/model/__init__.py index 9481e75542ed..05c3e9928fab 100644 --- a/nemo/llm/gpt/model/__init__.py +++ b/nemo/llm/gpt/model/__init__.py @@ -1,3 +1,12 @@ from nemo.llm.gpt.model.base import GPTConfig, GPTModel, MaskedTokenLossReduction, gpt_data_step, gpt_forward_step +from nemo.llm.gpt.model.mistral_7b import Mistral7BConfig, Mistral7BModel -__all__ = ["GPTConfig", "GPTModel", "MaskedTokenLossReduction", "gpt_data_step", "gpt_forward_step"] +__all__ = [ + "GPTConfig", + "GPTModel", + "Mistral7BConfig", + "Mistral7BModel", + "MaskedTokenLossReduction", + "gpt_data_step", + "gpt_forward_step", +] diff --git a/nemo/llm/gpt/model/base.py b/nemo/llm/gpt/model/base.py index 93186a7e7e08..554870712a36 100644 --- a/nemo/llm/gpt/model/base.py +++ b/nemo/llm/gpt/model/base.py @@ -23,7 +23,9 @@ class GPTConfig(TransformerConfig, io.IOMixin): fp16_lm_cross_entropy: bool = False parallel_output: bool = True share_embeddings_and_output_weights: bool = False + make_vocab_size_divisible_by: int = 128 position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute" + rotary_base: int = 10000 rotary_percent: float = 1.0 seq_len_interpolation_factor: Optional[float] = None seq_length: int = 1024 @@ -48,20 +50,21 @@ def configure_model(self, tokenizer) -> "MCoreGPTModel": return MCoreGPTModel( self, transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), - vocab_size=get_vocab_size(self, tokenizer.vocab_size), + vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by), max_sequence_length=self.seq_length, fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, parallel_output=self.parallel_output, share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, position_embedding_type=self.position_embedding_type, rotary_percent=self.rotary_percent, + rotary_base=self.rotary_base, seq_len_interpolation_factor=self.seq_len_interpolation_factor, pre_process=parallel_state.is_pipeline_first_stage(), post_process=parallel_state.is_pipeline_last_stage(), ) -class GPTModel(L.LightningModule, io.IOMixin): +class GPTModel(L.LightningModule, io.IOMixin, io.ConnectorMixin): def __init__( self, config: GPTConfig, diff --git a/nemo/llm/gpt/model/mistral_7b.py b/nemo/llm/gpt/model/mistral_7b.py new file mode 100644 index 000000000000..83d3b3412a39 --- /dev/null +++ b/nemo/llm/gpt/model/mistral_7b.py @@ -0,0 +1,263 @@ +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Callable, List, Optional + +import torch +import torch.nn.functional as F + +from nemo import io +from nemo.llm.gpt.model.base import GPTConfig, GPTModel + +if TYPE_CHECKING: + from transformers import MistralConfig, MistralForCausalLM + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + +@dataclass +class Mistral7BConfig(GPTConfig): + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + position_embedding_type: str = "rope" + add_bias_linear: bool = False + gated_linear_unit: bool = True + apply_query_key_layer_scaling: bool = True + + num_layers: int = 32 + hidden_size: int = 4096 + num_attention_heads: int = 32 + num_query_groups: int = 8 + ffn_hidden_size: int = 14336 + seq_length: int = 32768 + + init_method_std: float = 0.02 + layernorm_epsilon: float = 1e-5 + window_size: List[int] = field(default_factory=lambda: [4096, 0]) + + +class Mistral7BModel(GPTModel): + def __init__(self, config: Optional[Mistral7BConfig] = None, tokenizer=None): + _tokenizer = tokenizer or HFMistral7BImporter().tokenizer + + super().__init__(config or Mistral7BConfig(), _tokenizer) + + +@io.model_importer(Mistral7BModel, "hf", default_path="mistralai/Mistral-7B-v0.1") +class HFMistral7BImporter(io.ModelConnector["MistralForCausalLM", Mistral7BModel]): + def init(self) -> Mistral7BModel: + return Mistral7BModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import MistralForCausalLM + + source = MistralForCausalLM.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) + + @property + def tokenizer(self) -> "AutoTokenizer": + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(str(self)) + + @property + def config(self) -> Mistral7BConfig: + from transformers import MistralConfig + + source = MistralConfig.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(mistral_vocab_size): + base = 128 + while mistral_vocab_size % base != 0: + base //= 2 + return base + + output = Mistral7BConfig( + seq_length=source.max_position_embeddings, + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + layernorm_epsilon=source.rms_norm_eps, + num_query_groups=source.num_key_value_heads, + rotary_base=source.rope_theta, + gated_linear_unit=True, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + window_size=[source.sliding_window, 0], + ) + + return output + + +@io.model_exporter(Mistral7BModel, "hf") +class HFMistral7BExporter(io.ModelConnector[Mistral7BModel, "MistralForCausalLM"]): + def init(self) -> "MistralForCausalLM": + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_config(self.config) + + def apply(self, output_path: Path) -> Path: + # TODO: Make it work with lazy init + # with torch.device("meta"): + # target = self.init() + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + # TODO: Make sure we don't need to do this + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1]) + + @property + def tokenizer(self): + return io.load_ckpt(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "MistralConfig": + source: Mistral7BConfig = io.load_ckpt(str(self)).model.config + + from transformers import MistralConfig + + return MistralConfig( + sliding_window=source.window_size[0], + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + rms_norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + rope_theta=source.rotary_base, + vocab_size=self.tokenizer.vocab_size, + ) + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +@io.state_transform( + source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), + target_key="decoder.layers.*.mlp.linear_fc1.weight", +) +def _import_linear_fc1(down, gate): + return torch.cat((down, gate), axis=0).float() + + +@io.state_transform( + source_key="decoder.layers.*.mlp.linear_fc1.weight", + target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), +) +def _export_linear_fc1(linear_fc1): + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + + return gate_proj, up_proj diff --git a/requirements/requirements_lightning.txt b/requirements/requirements_lightning.txt index 6acfddad9189..5ad2519cfd1a 100644 --- a/requirements/requirements_lightning.txt +++ b/requirements/requirements_lightning.txt @@ -1,3 +1,4 @@ +cloudpickle fiddle hydra-core>1.3,<=1.3.2 omegaconf<=2.3 diff --git a/tests/io/test_state.py b/tests/io/test_state.py new file mode 100644 index 000000000000..bb5dc4a9af3d --- /dev/null +++ b/tests/io/test_state.py @@ -0,0 +1,233 @@ +import pytest +from torch import nn + +from nemo.io.state import StateDictTransform, TransformCTX, state_transform + + +class TestStateDictTransform: + """ + Tests for the StateDictTransform functionality. + """ + + @pytest.fixture + def mock_ctx(self): + """ + Provides a mock transformation context with predefined source and target states. + + Returns + ------- + TransformCTX: A context object with source and target states. + """ + source_state = { + 'model.layers.0.self_attn.q_proj.weight': 1, + 'model.layers.0.self_attn.k_proj.weight': 2, + 'model.layers.0.self_attn.v_proj.weight': 3, + 'model.layers.1.self_attn.q_proj.weight': 1, + 'model.layers.1.self_attn.k_proj.weight': 2, + 'model.layers.1.self_attn.v_proj.weight': 3, + } + target_state = { + "decoder.layers.0.self_attention.linear_qkv.weight": 10, + "decoder.layers.1.self_attention.linear_qkv.weight": 10, + } + ctx = TransformCTX( + source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state + ) + return ctx + + @pytest.fixture + def mock_multi_target_ctx(self): + """ + Provides a mock transformation context with a source state that matches the expected source_key + and a target state prepared with initial values for the expected target_keys. + """ + source_state = {'model.layers.1.self_attn.q_proj.weight': 1} + # Populate target_state with initial placeholder values for keys expected to be matched and updated + target_state = { + 'decoder.layers.1.self_attention.linear_q.weight': 0, + 'decoder.layers.1.self_attention.linear_k.weight': 0, + } + ctx = TransformCTX( + source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state + ) + return ctx + + def test_transform_with_multiple_source_keys(self, mock_ctx): + """ + Test transformation when multiple source keys are specified. + """ + transform = StateDictTransform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", + transform=lambda ctx, k, q, v: q + k + v, + ) + transform(mock_ctx) + assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 6 + assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 6 + + def test_transform_with_wildcard_in_source_keys(self, mock_ctx): + """ + Test transformation using a wildcard pattern in source keys. + """ + transform = StateDictTransform( + source_key="model.layers.*.self_attn.*_proj.weight", + target_key="decoder.layers.*.self_attention.linear_qkv.weight", + transform=lambda ctx, k, q, v: q + k + v, + ) + transform(mock_ctx) + assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 6 + assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 6 + + def test_transform_with_mapped_source_keys(self, mock_ctx): + """ + Test transformation with a dictionary mapping for source keys. + """ + transform = StateDictTransform( + source_key={ + "k": "model.layers.*.self_attn.k_proj.weight", + "q": "model.layers.*.self_attn.q_proj.weight", + "v": "model.layers.*.self_attn.v_proj.weight", + }, + target_key="decoder.layers.*.self_attention.linear_qkv.weight", + transform=lambda ctx, k, q, v: q + k + v, + ) + transform(mock_ctx) + assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 6 + assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 6 + + def test_transform_with_variable_arguments(self, mock_ctx): + """ + Test transformation with a wildcard pattern and variable arguments. + """ + transform = StateDictTransform( + source_key="model.layers.*.self_attn.*_proj.weight", + target_key="decoder.layers.*.self_attention.linear_qkv.weight", + transform=lambda ctx, *args: sum(args), + ) + transform(mock_ctx) + assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 6 + assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 6 + + def test_transform_with_no_matching_source_keys(self, mock_ctx): + """ + Test transformation when no source keys match the pattern. + """ + transform = StateDictTransform( + source_key="non.existent.pattern", + target_key="decoder.layers.*.self_attention.linear_qkv.weight", + transform=lambda ctx, *args: sum(args), + ) + with pytest.raises(ValueError): + transform(mock_ctx) + + def test_transform_with_invalid_transform_function(self, mock_ctx): + """ + Test transformation with a transform function that does not match expected signature. + """ + transform = StateDictTransform( + source_key="model.layers.*.self_attn.q_proj.weight", + target_key="decoder.layers.*.self_attention.linear_qkv.weight", + transform=lambda ctx: 0, # Invalid signature + ) + with pytest.raises(ValueError): + transform(mock_ctx) + + def test_transform_with_tuple_target_key_and_multiple_outputs(self, mock_multi_target_ctx): + """ + Test transformation where the target_key is a tuple and the transform function + returns multiple values that are then unrolled to these target keys. + """ + # Define a transformation that splits the input into two parts + def split_transform(ctx, x): + return x - 1, x + 1 + + # Apply the transformation + transform = StateDictTransform( + source_key="model.layers.1.self_attn.q_proj.weight", + target_key=( + "decoder.layers.1.self_attention.linear_q.weight", + "decoder.layers.1.self_attention.linear_k.weight", + ), + transform=split_transform, + ) + transform(mock_multi_target_ctx) + + # Check that the target state has been updated correctly + assert mock_multi_target_ctx.target_state["decoder.layers.1.self_attention.linear_q.weight"] == 0 + assert mock_multi_target_ctx.target_state["decoder.layers.1.self_attention.linear_k.weight"] == 2 + + +class TestStateTransformDecorator: + """ + Tests for the @state_transform decorator functionality. + """ + + @pytest.fixture + def mock_ctx(self): + """ + Provides a mock transformation context with predefined source and target states. + """ + source_state = { + 'model.layers.1.self_attn.q_proj.weight': 1, + 'model.layers.1.self_attn.k_proj.weight': 2, + 'model.layers.1.self_attn.v_proj.weight': 3, + } + # Pre-populate target_state with initial values or placeholders + target_state = { + "decoder.layers.1.self_attention.linear_q.weight": 0, + "decoder.layers.1.self_attention.linear_k.weight": 0, + "decoder.layers.1.self_attention.linear_v.weight": 0, + } + ctx = TransformCTX( + source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state + ) + return ctx + + def test_single_transform(self, mock_ctx): + """ + Test the @state_transform decorator with a single source and target key. + """ + # Apply the transformation + single_transform(mock_ctx) + # Verify the target state is updated correctly + assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_q.weight"] == 11 + + def test_multiple_outputs_transform(self, mock_ctx): + """ + Test the @state_transform decorator with a single source key and multiple target keys. + """ + # Apply the transformation + multiple_outputs_transform(mock_ctx) + # Verify the target state is updated correctly for each key + assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_q.weight"] == 2 + assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_k.weight"] == 1 + assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_v.weight"] == 3 + + +@state_transform( + source_key="model.layers.*.self_attn.q_proj.weight", target_key="decoder.layers.1.self_attention.linear_q.weight" +) +def single_transform(ctx, x): + """ + A single transformation function that adds 10 to the input value. + """ + return x + 10 + + +@state_transform( + source_key="model.layers.1.self_attn.*_proj.weight", + target_key=( + "decoder.layers.1.self_attention.linear_q.weight", + "decoder.layers.1.self_attention.linear_k.weight", + "decoder.layers.1.self_attention.linear_v.weight", + ), +) +def multiple_outputs_transform(ctx, *args): + """ + A transformation function that returns multiple values for multiple target keys. + """ + return args From 894e5022651f6b31523964333d07937344d258f0 Mon Sep 17 00:00:00 2001 From: Vladimir Bataev Date: Fri, 3 May 2024 15:10:15 +0400 Subject: [PATCH 15/73] RNN-T and TDT inference: use CUDA graphs by default (#8972) * Use Cuda graphs by default for RNN-T and TDT Signed-off-by: Vladimir Bataev --------- Signed-off-by: Vladimir Bataev --- nemo/collections/asr/models/asr_model.py | 51 ++++- nemo/collections/asr/modules/rnnt.py | 4 +- .../cuda_graph_rnnt_greedy_decoding.py | 13 +- .../asr/parts/submodules/rnnt_decoding.py | 4 +- .../parts/submodules/rnnt_greedy_decoding.py | 98 +++++++-- .../submodules/rnnt_loop_labels_computer.py | 180 +++++++++++++--- .../submodules/tdt_loop_labels_computer.py | 199 ++++++++++++++---- .../common/parts/optional_cuda_graphs.py | 89 ++++++++ nemo/core/utils/cuda_python_utils.py | 2 +- .../asr/decoding/rnnt_alignments_check.py | 12 +- .../test_cuda_graph_rnnt_greedy_decoding.py | 138 +++++++++--- .../asr/test_asr_rnnt_encdec_model.py | 18 +- .../common/test_optional_cuda_graphs.py | 71 +++++++ 13 files changed, 746 insertions(+), 133 deletions(-) create mode 100644 nemo/collections/common/parts/optional_cuda_graphs.py create mode 100644 tests/collections/common/test_optional_cuda_graphs.py diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index e14424cec5c1..0539f961a1ca 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from abc import ABC, abstractmethod -from typing import List +from abc import ABC +from typing import List, Optional import torch +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo from nemo.core.classes.exportable import Exportable @@ -171,6 +172,52 @@ def on_after_backward(self): logging.warning(f'detected inf or nan values in gradients! Setting gradients to zero.') self.zero_grad() + def on_train_epoch_start(self) -> None: + """ + Decoder with CUDA graphs does not release memory, thus we disable it for training epoch. + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs + """ + WithOptionalCudaGraphs.disable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + + def on_train_epoch_end(self) -> None: + """ + After training, we can enable the decoder with CUDA graphs. + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs + """ + WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + + def on_validation_epoch_start(self) -> None: + """ + For validation, we enable CUDA graphs to speedup validation. + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs. + """ + WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + + def on_validation_epoch_end(self) -> Optional[dict[str, dict[str, torch.Tensor]]]: + """ + After validation, we disable CUDA graphs, since `validation` can be called in training loop, and + training will continue after validation + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs. + """ + WithOptionalCudaGraphs.disable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + return super().on_validation_epoch_end() + + def on_test_epoch_start(self) -> None: + """ + For testing, we enable CUDA graphs to speedup validation. + We do not need to disable CUDA graphs after testing, since `test` cannot be called in training loop. + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs. + """ + WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + + def on_predict_epoch_start(self) -> None: + """ + For predicting, we enable CUDA graphs to speedup validation. + We do not need to disable CUDA graphs after predicting, since `predict` cannot be called in training loop. + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs + """ + WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + class ExportableEncDecModel(Exportable): """ diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 055066c00660..2355cfb7005b 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -312,7 +312,9 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: batch = y.size(0) # state contains context_size - 1 elements for each utterance in batch, # consistent with the state returned from StatelessNet.forward - state = [torch.ones([batch, self.context_size - 1], dtype=torch.long, device=y.device) * self.blank_idx] + state = [ + torch.full([batch, self.context_size - 1], fill_value=self.blank_idx, dtype=torch.long, device=y.device) + ] return state def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]): diff --git a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py index 388737443fd4..93cef4d4138e 100644 --- a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py @@ -292,14 +292,21 @@ def __call__( partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, ): if partial_hypotheses is not None: - raise NotImplementedError("`partial_hypotheses` support is not available with cuda graphs (but could be)") + raise NotImplementedError( + "`partial_hypotheses` support is not available " + "with Frame-Looping algorithm with Cuda graphs (not implemented yet)" + ) if self.caller.preserve_alignments: - raise NotImplementedError("`preserve_alignments` support is not available with cuda graphs (but could be)") + raise NotImplementedError( + "`preserve_alignments` support is not available" + "with Frame-Looping algorithm with Cuda graphs (not implemented yet)" + ) if self.caller.preserve_frame_confidence: raise NotImplementedError( - "`preserve_frame_confidence` support is not available with cuda graphs (but could be)" + "`preserve_frame_confidence` support is not available" + "with Frame-Looping algorithm with Cuda graphs (not implemented yet)" ) batch_size = x.shape[0] diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 71079f4b6382..5fa225864f8c 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -331,7 +331,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_frame_confidence=self.preserve_frame_confidence, confidence_method_cfg=self.confidence_method_cfg, loop_labels=self.cfg.greedy.get('loop_labels', True), - use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), + use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True), ) else: self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer( @@ -347,7 +347,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_frame_confidence=self.preserve_frame_confidence, include_duration_confidence=self.tdt_include_duration_confidence, confidence_method_cfg=self.confidence_method_cfg, - use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), + use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True), ) else: diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index e5de99cf0776..b2fa9b85b5fd 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -38,6 +38,7 @@ from nemo.collections.asr.parts.submodules.tdt_loop_labels_computer import GreedyBatchedTDTLoopLabelsComputer from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.collections.common.parts.rnn import label_collate from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType @@ -508,7 +509,7 @@ def _greedy_decode( return hypothesis -class GreedyBatchedRNNTInfer(_GreedyRNNTInfer): +class GreedyBatchedRNNTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): """A batch level greedy transducer decoder. Batch level greedy decoding, performed auto-regressively. @@ -589,7 +590,7 @@ def __init__( preserve_frame_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, loop_labels: bool = True, - use_cuda_graph_decoder: bool = False, + use_cuda_graph_decoder: bool = True, ): super().__init__( decoder_model=decoder_model, @@ -602,13 +603,14 @@ def __init__( ) self.use_cuda_graph_decoder = use_cuda_graph_decoder + self.loop_labels = loop_labels # Depending on availability of `blank_as_pad` support # switch between more efficient batch decoding technique self._decoding_computer = None if self.decoder.blank_as_pad: - if loop_labels: - # default (faster) algo: loop over labels + if self.loop_labels: + # Label-Looping algorithm (default, faster) self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels self._decoding_computer = GreedyBatchedRNNTLoopLabelsComputer( decoder=self.decoder, @@ -618,20 +620,74 @@ def __init__( preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, confidence_method_cfg=confidence_method_cfg, - allow_cuda_graphs=use_cuda_graph_decoder, + allow_cuda_graphs=self.use_cuda_graph_decoder, ) - elif use_cuda_graph_decoder: - from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import ( - RNNTGreedyDecodeCudaGraph, - ) - - self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, self) else: - # previous algo: loop over frames - self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames + # Frame-Looping algorithm + if not self.use_cuda_graph_decoder: + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames + else: + if self.preserve_alignments: + logging.warning("`preserve_alignments` is not implemented for Frame-Looping + CUDA graphs") + self.use_cuda_graph_decoder = False + if self.preserve_frame_confidence: + logging.warning( + "`preserve_frame_confidence` is not implemented for Frame-Looping + CUDA graphs" + ) + self.use_cuda_graph_decoder = False + if not torch.cuda.is_available(): + self.use_cuda_graph_decoder = False + + if self.use_cuda_graph_decoder: + try: + from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import ( + RNNTGreedyDecodeCudaGraph, + ) + + self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, self) + except (ImportError, ModuleNotFoundError, ValueError) as e: + self.use_cuda_graph_decoder = False + logging.warning(f"Cannot use decoder with CUDA graphs, reason: {e.msg}") + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames + else: + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames else: self._greedy_decode = self._greedy_decode_masked + def disable_cuda_graphs(self): + """Disable CUDA graphs (e.g., for decoding in training)""" + if not self.use_cuda_graph_decoder: + # CUDA graphs not allowed, nothing to do + return + + if not self.decoder.blank_as_pad: + # blank as pad uses decoding without CUDA graphs + return + + if self.loop_labels: + # Label-Looping implementation + self._decoding_computer.disable_cuda_graphs() + else: + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames + + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs (if allowed)""" + if not self.use_cuda_graph_decoder: + # CUDA graphs not allowed, nothing to do + return + + if not self.decoder.blank_as_pad: + # blank as pad uses decoding without CUDA graphs + return + + if self.loop_labels: + # Label-Looping implementation + self._decoding_computer.maybe_enable_cuda_graphs() + else: + from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import RNNTGreedyDecodeCudaGraph + + self._greedy_decode = RNNTGreedyDecodeCudaGraph(self.max_symbols, self) + @typecheck() def forward( self, @@ -2302,7 +2358,7 @@ class GreedyBatchedRNNTInferConfig: tdt_include_duration_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) loop_labels: bool = True - use_cuda_graph_decoder: bool = False + use_cuda_graph_decoder: bool = True def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed @@ -2580,7 +2636,7 @@ def _greedy_decode( return hypothesis -class GreedyBatchedTDTInfer(_GreedyRNNTInfer): +class GreedyBatchedTDTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): """A batch level greedy TDT decoder. Batch level greedy decoding, performed auto-regressively. Args: @@ -2652,7 +2708,7 @@ def __init__( preserve_frame_confidence: bool = False, include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, - use_cuda_graph_decoder: bool = False, + use_cuda_graph_decoder: bool = True, ): super().__init__( decoder_model=decoder_model, @@ -2759,3 +2815,13 @@ def _greedy_decode_blank_as_pad_loop_labels( for hyp, state in zip(hyps, self.decoder.batch_split_states(last_decoder_state)): hyp.dec_state = state return hyps + + def disable_cuda_graphs(self): + """Disable CUDA graphs (e.g., for decoding in training)""" + if self._decoding_computer is not None: + self._decoding_computer.disable_cuda_graphs() + + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs (if allowed)""" + if self._decoding_computer is not None: + self._decoding_computer.maybe_enable_cuda_graphs() diff --git a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py index 92cb8a36aeb5..b920dba09cfd 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Tuple +from dataclasses import dataclass, field +from typing import Any, Optional, Tuple, Union import numpy as np import torch @@ -21,6 +22,7 @@ from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.core.utils.cuda_python_utils import ( check_cuda_python_cuda_graphs_conditional_nodes_supported, cu_call, @@ -28,6 +30,7 @@ with_conditional_node, ) from nemo.utils import logging +from nemo.utils.enum import PrettyStrEnum try: from cuda import cudart @@ -161,7 +164,17 @@ def need_reinit(self, encoder_output_projected: torch.Tensor) -> bool: ) -class GreedyBatchedRNNTLoopLabelsComputer(ConfidenceMethodMixin): +@dataclass +class SeparateGraphsLoopLabels: + """Class to store Cuda graphs for decoding when separate graphs are used""" + + before_outer_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + before_inner_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + inner_loop_code: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + after_inner_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + + +class GreedyBatchedRNNTLoopLabelsComputer(WithOptionalCudaGraphs, ConfidenceMethodMixin): """ Label Looping algorithm implementation: optimized batched greedy decoding. Callable. Iterates over labels, on each step finding the next non-blank label @@ -174,6 +187,16 @@ class GreedyBatchedRNNTLoopLabelsComputer(ConfidenceMethodMixin): INITIAL_MAX_TIME = 375 # initial max time, used to init state for Cuda graphs CUDA_PROGRAM_NAME = b"while_loop_labels_conditional_rnnt.cu" + class CudaGraphsMode(PrettyStrEnum): + FULL_GRAPH = "full_graph" # Cuda graphs with conditional nodes, fastest implementation + NO_WHILE_LOOPS = "no_while_loops" # Decoding with PyTorch while loops + partial Cuda graphs + NO_GRAPHS = "no_graphs" # decoding without graphs, stateful implementation, only for testing purposes + + separate_graphs: Optional[SeparateGraphsLoopLabels] + full_graph: Optional[torch.cuda.CUDAGraph] + cuda_graphs_mode: Optional[CudaGraphsMode] + state: Optional[LoopLabelsState] + def __init__( self, decoder, @@ -203,24 +226,66 @@ def __init__( self.max_symbols = max_symbols_per_step self.preserve_alignments = preserve_alignments self.preserve_frame_confidence = preserve_frame_confidence + self.allow_cuda_graphs = allow_cuda_graphs self._SOS = self._blank_index self._init_confidence_method(confidence_method_cfg=confidence_method_cfg) assert self._SOS == self._blank_index # "blank as pad" algorithm only - self.use_cuda_graphs = allow_cuda_graphs + self.state = None + self.full_graph = None + self.separate_graphs = None - if self.use_cuda_graphs and self.max_symbols is None: - logging.warning("Max symbols is None, which is not allowed with Cuda graphs.") - self.use_cuda_graphs = False + self.cuda_graphs_mode = None + self.maybe_enable_cuda_graphs() - if self.use_cuda_graphs: + def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]): + """ + Method to set graphs mode. Use only for testing purposes. + For debugging the algorithm use "no_graphs" mode, since it is impossible to debug CUDA graphs directly. + """ + self.cuda_graphs_mode = self.CudaGraphsMode(mode) if mode is not None else None + self.state = None + + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs if conditions met""" + if self.cuda_graphs_mode is not None: + # CUDA graphs are already enabled + return + + if not self.allow_cuda_graphs: + self.cuda_graphs_mode = None + else: + # cuda graphs are allowed + # check basic requirements for cuda graphs + if self.max_symbols is None: + logging.warning("Max symbols per step is None, which is not allowed with Cuda graphs. Setting to `10`") + self.max_symbols = 10 + # basic requirements met, need to check while loops try: check_cuda_python_cuda_graphs_conditional_nodes_supported() - except ImportError as e: - logging.warning(f"No conditional node support. Cuda graphs will be disabled,\n{e.msg}") - self.use_cuda_graphs = False - - self.state: Optional[LoopLabelsState] = None + self.cuda_graphs_mode = self.CudaGraphsMode.FULL_GRAPH + except (ImportError, ModuleNotFoundError) as e: + logging.warning( + "No conditional node support for Cuda.\n" + "Cuda graphs with while loops are disabled, decoding speed will be slower\n" + f"Reason: {e.msg}" + ) + self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS + self.reset_cuda_graphs_state() + + def disable_cuda_graphs(self): + """Disable CUDA graphs, can be used to disable graphs temporary, e.g., in training process""" + if self.cuda_graphs_mode is None: + # nothing to disable + return + self.cuda_graphs_mode = None + self.reset_cuda_graphs_state() + + def reset_cuda_graphs_state(self): + """Reset state to release memory (for CUDA graphs implementations)""" + self.state = None + self.full_graph = None + self.separate_graphs = None def loop_labels_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, @@ -237,6 +302,7 @@ def loop_labels_torch( # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) + float_dtype = encoder_output_projected.dtype # init output structures: BatchedHyps (for results), BatchedAlignments + last decoder state # init empty batched hypotheses @@ -244,7 +310,7 @@ def loop_labels_torch( batch_size=batch_size, init_length=max_time * self.max_symbols if self.max_symbols is not None else max_time, device=device, - float_dtype=encoder_output_projected.dtype, + float_dtype=float_dtype, ) # sample state, will be replaced further when the decoding for hypothesis is done last_decoder_state = self.decoder.initialize_state(encoder_output_projected) @@ -256,7 +322,7 @@ def loop_labels_torch( logits_dim=self.joint.num_classes_with_blank, init_length=max_time * 2 if use_alignments else 1, # blank for each timestep + text tokens device=device, - float_dtype=encoder_output_projected.dtype, + float_dtype=float_dtype, store_alignments=self.preserve_alignments, store_frame_confidence=self.preserve_frame_confidence, ) @@ -312,7 +378,7 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)) + confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -350,7 +416,7 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)) + confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -413,6 +479,8 @@ def loop_labels_cuda_graphs( encoder_output: output from the encoder encoder_output_length: lengths of the utterances in `encoder_output` """ + assert self.cuda_graphs_mode is not None + # do not recalculate joint projection, project only once encoder_output = self.joint.project_encoder(encoder_output) current_batch_size = encoder_output.shape[0] @@ -430,16 +498,27 @@ def loop_labels_cuda_graphs( self.state.encoder_output_length[: encoder_output_length.shape[0]].copy_(encoder_output_length) # set length to zero for elements outside the current batch self.state.encoder_output_length[current_batch_size:].fill_(0) - self.graph.replay() - - # example manual loop (can be used instead of graph.replay()) - # self._before_outer_loop() - # while self.state.active_mask_any.item(): - # self._before_inner_loop_get_decoder_output() - # self._before_inner_loop_get_joint_output() - # while self.state.advance_mask_any.item(): - # self._inner_loop_code() - # self._after_inner_loop() + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + self.full_graph.replay() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: + self.separate_graphs.before_outer_loop.replay() + while self.state.active_mask_any.item(): + self.separate_graphs.before_inner_loop.replay() + while self.state.advance_mask_any.item(): + self.separate_graphs.inner_loop_code.replay() + self.separate_graphs.after_inner_loop.replay() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: + # this mode is only for testing purposes + # manual loop instead of using graphs + self._before_outer_loop() + while self.state.active_mask_any.item(): + self._before_inner_loop_get_decoder_output() + self._before_inner_loop_get_joint_output() + while self.state.advance_mask_any.item(): + self._inner_loop_code() + self._after_inner_loop() + else: + raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}") return ( self.state.batched_hyps, @@ -509,12 +588,49 @@ def _graph_reinitialize( ) # to avoid recalculation of joint projection, store decoder output in state self.state.decoder_output = self.joint.project_prednet(decoder_output) + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + self._full_graph_compile() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: + self._partial_graphs_compile() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: + # no graphs needed + pass + else: + raise NotImplementedError + + def _partial_graphs_compile(self): + """Compile decoding by parts""" + # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. + stream_for_graph = torch.cuda.Stream(self.state.device) + self.separate_graphs = SeparateGraphsLoopLabels() + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.before_outer_loop, stream=stream_for_graph + ): + self._before_outer_loop() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.before_inner_loop, stream=stream_for_graph + ): + self._before_inner_loop_get_decoder_output() + self._before_inner_loop_get_joint_output() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.inner_loop_code, stream=stream_for_graph + ): + self._inner_loop_code() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.after_inner_loop, stream=stream_for_graph + ): + self._after_inner_loop() + def _full_graph_compile(self): + """Compile full graph for decoding""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.state.device) - self.graph = torch.cuda.CUDAGraph() + self.full_graph = torch.cuda.CUDAGraph() with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.graph, stream=stream_for_graph + self.full_graph, stream=stream_for_graph ): self._before_outer_loop() @@ -612,12 +728,13 @@ def _before_inner_loop_get_joint_output(self): # blank_mask = self.labels == self._blank_index self.state.time_indices_current_labels.copy_(self.state.time_indices, non_blocking=True) if self.state.alignments is not None: + float_dtype = self.state.float_dtype self.state.alignments.add_results_masked_no_checks_( active_mask=self.state.active_mask, time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=self.state.labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)) + confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -662,12 +779,13 @@ def _inner_loop_code(self): torch.where(self.state.advance_mask, more_scores, self.state.scores, out=self.state.scores) if self.state.alignments is not None: + float_dtype = self.state.float_dtype self.state.alignments.add_results_masked_no_checks_( active_mask=self.state.advance_mask, time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)) + confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -721,7 +839,7 @@ def _after_inner_loop(self): def __call__( self, x: torch.Tensor, out_len: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: - if self.use_cuda_graphs and x.device.type == "cuda": + if self.cuda_graphs_mode is not None and x.device.type == "cuda": return self.loop_labels_cuda_graphs(encoder_output=x, encoder_output_length=out_len) return self.loop_labels_torch(encoder_output=x, encoder_output_length=out_len) diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py index b136446d97fb..4e514966db2b 100644 --- a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py @@ -13,6 +13,7 @@ # limitations under the License. +from dataclasses import dataclass, field from typing import Any, Optional, Tuple, Union import numpy as np @@ -22,6 +23,7 @@ from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.core.utils.cuda_python_utils import ( check_cuda_python_cuda_graphs_conditional_nodes_supported, cu_call, @@ -29,6 +31,7 @@ with_conditional_node, ) from nemo.utils import logging +from nemo.utils.enum import PrettyStrEnum try: from cuda import cudart @@ -167,7 +170,17 @@ def need_reinit(self, encoder_output_projected: torch.Tensor) -> bool: ) -class GreedyBatchedTDTLoopLabelsComputer(ConfidenceMethodMixin): +@dataclass +class SeparateGraphsLoopLabels: + """Class to store Cuda graphs for decoding when separate graphs are used""" + + before_outer_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + before_inner_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + inner_loop_code: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + after_inner_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + + +class GreedyBatchedTDTLoopLabelsComputer(WithOptionalCudaGraphs, ConfidenceMethodMixin): """ Label Looping algorithm implementation: optimized batched greedy decoding. Callable. Iterates over labels, on each step finding the next non-blank label @@ -180,6 +193,16 @@ class GreedyBatchedTDTLoopLabelsComputer(ConfidenceMethodMixin): INITIAL_MAX_TIME = 375 # initial max time, used to init state for Cuda graphs CUDA_PROGRAM_NAME = b"while_loop_labels_conditional_tdt.cu" + class CudaGraphsMode(PrettyStrEnum): + FULL_GRAPH = "full_graph" # Cuda graphs with conditional nodes, fastest implementation + NO_WHILE_LOOPS = "no_while_loops" # Decoding with PyTorch while loops + partial Cuda graphs + NO_GRAPHS = "no_graphs" # decoding without graphs, stateful implementation, only for testing purposes + + separate_graphs: Optional[SeparateGraphsLoopLabels] + full_graph: Optional[torch.cuda.CUDAGraph] + cuda_graphs_mode: Optional[CudaGraphsMode] + state: Optional[LoopLabelsState] + def __init__( self, decoder, @@ -215,25 +238,67 @@ def __init__( self.max_symbols = max_symbols_per_step self.preserve_alignments = preserve_alignments self.preserve_frame_confidence = preserve_frame_confidence + self.allow_cuda_graphs = allow_cuda_graphs self.include_duration_confidence = include_duration_confidence self._SOS = self._blank_index self._init_confidence_method(confidence_method_cfg=confidence_method_cfg) assert self._SOS == self._blank_index # "blank as pad" algorithm only - self.use_cuda_graphs = allow_cuda_graphs + self.state = None + self.full_graph = None + self.separate_graphs = None - if self.use_cuda_graphs and self.max_symbols is None: - logging.warning("Max symbols is None, which is not allowed with Cuda graphs.") - self.use_cuda_graphs = False + self.cuda_graphs_mode = None + self.maybe_enable_cuda_graphs() - if self.use_cuda_graphs: + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs if conditions met""" + if self.cuda_graphs_mode is not None: + # CUDA graphs are enabled + return + + if not self.allow_cuda_graphs: + self.cuda_graphs_mode = None + else: + # cuda graphs are allowed + # check basic requirements for cuda graphs + if self.max_symbols is None: + logging.warning("Max symbols per step is None, which is not allowed with Cuda graphs. Setting to `10`") + self.max_symbols = 10 + # basic requirements met, need to check while loops try: check_cuda_python_cuda_graphs_conditional_nodes_supported() - except ImportError as e: - logging.warning(f"No conditional node support. Cuda graphs will be disabled,\n{e.msg}") - self.use_cuda_graphs = False - - self.state: Optional[LoopLabelsState] = None + self.cuda_graphs_mode = self.CudaGraphsMode.FULL_GRAPH + except (ImportError, ModuleNotFoundError) as e: + logging.warning( + "No conditional node support for Cuda.\n" + "Cuda graphs with while loops are disabled, decoding speed will be slower\n" + f"Reason: {e.msg}" + ) + self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS + self.reset_cuda_graphs_state() + + def disable_cuda_graphs(self): + """Disable CUDA graphs, can be used to disable graphs temporary, e.g., in training process""" + if self.cuda_graphs_mode is None: + # nothing to disable + return + self.cuda_graphs_mode = None + self.reset_cuda_graphs_state() + + def reset_cuda_graphs_state(self): + """Reset state to release memory (for CUDA graphs implementations)""" + self.state = None + self.full_graph = None + self.separate_graphs = None + + def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]): + """ + Method to set graphs mode. Use only for testing purposes. + For debugging the algorithm use "no_graphs" mode, since it is impossible to debug CUDA graphs directly. + """ + self.cuda_graphs_mode = self.CudaGraphsMode(mode) if mode is not None else None + self.state = None def loop_labels_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, @@ -250,7 +315,7 @@ def loop_labels_torch( # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) - dtype = encoder_output_projected.dtype + float_dtype = encoder_output_projected.dtype # init output structures: BatchedHyps (for results), BatchedAlignments + last decoder state # init empty batched hypotheses @@ -258,7 +323,7 @@ def loop_labels_torch( batch_size=batch_size, init_length=max_time * self.max_symbols if self.max_symbols is not None else max_time, device=device, - float_dtype=dtype, + float_dtype=float_dtype, ) # sample state, will be replaced further when the decoding for hypothesis is done last_decoder_state = self.decoder.initialize_state(encoder_output_projected) @@ -270,7 +335,7 @@ def loop_labels_torch( logits_dim=self.joint.num_classes_with_blank, init_length=max_time * 2 if use_alignments else 1, # blank for each timestep + text tokens device=device, - float_dtype=dtype, + float_dtype=float_dtype, store_alignments=self.preserve_alignments, store_frame_confidence=self.preserve_frame_confidence, with_duration_confidence=self.include_duration_confidence, @@ -338,16 +403,18 @@ def loop_labels_torch( confidence=torch.stack( ( self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( - dtype=dtype + dtype=float_dtype ), self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( - dtype=dtype + dtype=float_dtype ), ), dim=-1, ) if self.include_duration_confidence - else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to(dtype=dtype) + else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( + dtype=float_dtype + ) if self.preserve_frame_confidence else None, ) @@ -390,17 +457,17 @@ def loop_labels_torch( confidence=torch.stack( ( self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( - dtype=dtype + dtype=float_dtype ), self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( - dtype=dtype + dtype=float_dtype ), ), dim=-1, ) if self.include_duration_confidence else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( - dtype=dtype + dtype=float_dtype ) if self.preserve_frame_confidence else None, @@ -467,6 +534,8 @@ def loop_labels_cuda_graphs( encoder_output: output from the encoder encoder_output_length: lengths of the utterances in `encoder_output` """ + assert self.cuda_graphs_mode is not None + # do not recalculate joint projection, project only once encoder_output = self.joint.project_encoder(encoder_output) current_batch_size = encoder_output.shape[0] @@ -484,16 +553,27 @@ def loop_labels_cuda_graphs( self.state.encoder_output_length[: encoder_output_length.shape[0]].copy_(encoder_output_length) # set length to zero for elements outside the current batch self.state.encoder_output_length[current_batch_size:].fill_(0) - self.graph.replay() - - # example manual loop (can be used instead of graph.replay()) - # self._before_outer_loop() - # while self.state.active_mask_any.item(): - # self._before_inner_loop_get_decoder_output() - # self._before_inner_loop_get_joint_output() - # while self.state.advance_mask_any.item(): - # self._inner_loop_code() - # self._after_inner_loop() + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + self.full_graph.replay() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: + self.separate_graphs.before_outer_loop.replay() + while self.state.active_mask_any.item(): + self.separate_graphs.before_inner_loop.replay() + while self.state.advance_mask_any.item(): + self.separate_graphs.inner_loop_code.replay() + self.separate_graphs.after_inner_loop.replay() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: + # this mode is only for testing purposes + # manual loop instead of using graphs + self._before_outer_loop() + while self.state.active_mask_any.item(): + self._before_inner_loop_get_decoder_output() + self._before_inner_loop_get_joint_output() + while self.state.advance_mask_any.item(): + self._inner_loop_code() + self._after_inner_loop() + else: + raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}") return ( self.state.batched_hyps, @@ -565,12 +645,49 @@ def _graph_reinitialize( ) # to avoid recalculation of joint projection, store decoder output in state self.state.decoder_output = self.joint.project_prednet(decoder_output) + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + self._full_graph_compile() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: + self._partial_graphs_compile() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: + # no graphs needed + pass + else: + raise NotImplementedError + + def _partial_graphs_compile(self): + """Compile decoding by parts""" + # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. + stream_for_graph = torch.cuda.Stream(self.state.device) + self.separate_graphs = SeparateGraphsLoopLabels() + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.before_outer_loop, stream=stream_for_graph + ): + self._before_outer_loop() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.before_inner_loop, stream=stream_for_graph + ): + self._before_inner_loop_get_decoder_output() + self._before_inner_loop_get_joint_output() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.inner_loop_code, stream=stream_for_graph + ): + self._inner_loop_code() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.after_inner_loop, stream=stream_for_graph + ): + self._after_inner_loop() + def _full_graph_compile(self): + """Compile full graph for decoding""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.state.device) - self.graph = torch.cuda.CUDAGraph() + self.full_graph = torch.cuda.CUDAGraph() with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.graph, stream=stream_for_graph + self.full_graph, stream=stream_for_graph ): self._before_outer_loop() @@ -651,7 +768,6 @@ def _before_inner_loop_get_joint_output(self): # stage 2: get joint output, iteratively seeking for non-blank labels # blank label in `labels` tensor means "end of hypothesis" (for this index) self.state.active_mask_prev.copy_(self.state.active_mask, non_blocking=True) - dtype = self.state.encoder_output_projected.dtype logits = ( self.joint.joint_after_projection( self.state.encoder_output_projected[self.state.batch_indices, self.state.safe_time_indices].unsqueeze( @@ -675,6 +791,7 @@ def _before_inner_loop_get_joint_output(self): # for blank labels force duration >= 1 durations.masked_fill_(torch.logical_and(durations == 0, self.state.blank_mask), 1) if self.state.alignments is not None: + float_dtype = self.state.float_dtype self.state.alignments.add_results_masked_no_checks_( active_mask=self.state.active_mask, time_indices=self.state.time_indices_current_labels, @@ -684,17 +801,17 @@ def _before_inner_loop_get_joint_output(self): ( self._get_confidence_tensor( F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=dtype), + ).to(dtype=float_dtype), self._get_confidence_tensor( F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) - ).to(dtype=dtype), + ).to(dtype=float_dtype), ), dim=-1, ) if self.include_duration_confidence else self._get_confidence_tensor( F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=dtype) + ).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -720,7 +837,6 @@ def _inner_loop_code(self): self.state.time_indices_current_labels, out=self.state.time_indices_current_labels, ) - dtype = self.state.encoder_output_projected.dtype logits = ( self.joint.joint_after_projection( self.state.encoder_output_projected[self.state.batch_indices, self.state.safe_time_indices].unsqueeze( @@ -742,6 +858,7 @@ def _inner_loop_code(self): torch.where(self.state.advance_mask, more_scores, self.state.scores, out=self.state.scores) if self.state.alignments is not None: + float_dtype = self.state.float_dtype self.state.alignments.add_results_masked_no_checks_( active_mask=self.state.advance_mask, time_indices=self.state.time_indices_current_labels, @@ -751,17 +868,17 @@ def _inner_loop_code(self): ( self._get_confidence_tensor( F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=dtype), + ).to(dtype=float_dtype), self._get_confidence_tensor( F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) - ).to(dtype=dtype), + ).to(dtype=float_dtype), ), dim=-1, ) if self.include_duration_confidence else self._get_confidence_tensor( F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=dtype) + ).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -822,7 +939,7 @@ def _after_inner_loop(self): def __call__( self, x: torch.Tensor, out_len: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: - if self.use_cuda_graphs and x.device.type == "cuda": + if self.cuda_graphs_mode is not None and x.device.type == "cuda": return self.loop_labels_cuda_graphs(encoder_output=x, encoder_output_length=out_len) return self.loop_labels_torch(encoder_output=x, encoder_output_length=out_len) diff --git a/nemo/collections/common/parts/optional_cuda_graphs.py b/nemo/collections/common/parts/optional_cuda_graphs.py new file mode 100644 index 000000000000..2417d9e00370 --- /dev/null +++ b/nemo/collections/common/parts/optional_cuda_graphs.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Optional + +import torch.nn as nn + +from nemo.utils import logging + + +class WithOptionalCudaGraphs(abc.ABC): + """ + Abstract interface for modules with CUDA graphs. + Allows to enable/disable CUDA graphs on the fly. + """ + + @classmethod + def disable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optional[str] = None): + """ + Disable CUDA graphs Enable CUDA graphs, finding submodule recursively. + + Args: + module: instance of nn.Module + attribute_path: field containing instance of WithOptionalCudaGraphs + E.g., "decoding.decoding" means that ".decoding.decoding" are checked. + If None, "" is checked. + """ + attributes = attribute_path.split(".") if attribute_path else [] + + for name, submodule in module.named_modules(): + object_to_check = submodule + try: + # recursively get attribute by iterating attribute_path + for attribute in attributes: + object_to_check = getattr(object_to_check, attribute) + except AttributeError: + continue # loop over modules, no attribute + + if isinstance(object_to_check, cls): + object_to_check.disable_cuda_graphs() + logging.info(f"Disabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes)) + + @classmethod + def enable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optional[str] = None): + """ + Enable CUDA graphs, finding submodule recursively + + Args: + module: instance of nn.Module + attribute_path: field containing instance of WithOptionalCudaGraphs + E.g., "decoding.decoding" means that ".decoding.decoding" are checked. + If None, "" is checked. + """ + attributes = attribute_path.split(".") if attribute_path else [] + + for name, submodule in module.named_modules(): + object_to_check = submodule + try: + # recursively get attribute by iterating attribute_path + for attribute in attributes: + object_to_check = getattr(object_to_check, attribute) + except AttributeError: + continue # loop over modules, no attribute + + if isinstance(object_to_check, cls): + object_to_check.maybe_enable_cuda_graphs() + logging.info(f"Enabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes)) + + @abc.abstractmethod + def disable_cuda_graphs(self): + """Disable (maybe temporary) CUDA graphs""" + raise NotImplementedError + + @abc.abstractmethod + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs if all conditions met""" + raise NotImplementedError diff --git a/nemo/core/utils/cuda_python_utils.py b/nemo/core/utils/cuda_python_utils.py index fb47c22ceee0..eb8897df0797 100644 --- a/nemo/core/utils/cuda_python_utils.py +++ b/nemo/core/utils/cuda_python_utils.py @@ -25,7 +25,7 @@ def check_cuda_python_cuda_graphs_conditional_nodes_supported(): try: from cuda import cuda except ImportError: - raise ModuleNotFoundError("Please do `pip install cuda-python>=12.3`") + raise ModuleNotFoundError("No `cuda-python` module. Please do `pip install cuda-python>=12.3`") from cuda import __version__ as cuda_python_version diff --git a/tests/collections/asr/decoding/rnnt_alignments_check.py b/tests/collections/asr/decoding/rnnt_alignments_check.py index aa4d5f044de1..d44f7f8fd985 100644 --- a/tests/collections/asr/decoding/rnnt_alignments_check.py +++ b/tests/collections/asr/decoding/rnnt_alignments_check.py @@ -28,13 +28,14 @@ PRETRAINED_MODEL_NAME = "stt_en_conformer_transducer_small" -def get_rnnt_alignments(strategy: str, loop_labels: bool = True, location="cuda"): +def get_rnnt_alignments(strategy: str, loop_labels: bool = True, use_cuda_graph_decoder=False, location="cuda"): cfg = OmegaConf.structured(TranscriptionConfig(pretrained_name=PRETRAINED_MODEL_NAME)) cfg.rnnt_decoding.confidence_cfg.preserve_frame_confidence = True cfg.rnnt_decoding.preserve_alignments = True cfg.rnnt_decoding.strategy = strategy if cfg.rnnt_decoding.strategy == "greedy_batch": cfg.rnnt_decoding.greedy.loop_labels = loop_labels + cfg.rnnt_decoding.greedy.use_cuda_graph_decoder = use_cuda_graph_decoder cfg.dataset_manifest = TEST_DATA_PATH filepaths = prepare_audio_data(cfg)[0][:10] # selecting 10 files only @@ -73,10 +74,15 @@ def cleanup_local_folder(): # TODO: add the same tests for multi-blank RNNT decoding @pytest.mark.skipif(not os.path.exists('/home/TestData'), reason='Not a Jenkins machine') @pytest.mark.parametrize("loop_labels", [True, False]) -def test_rnnt_alignments(loop_labels: bool): +@pytest.mark.parametrize("use_cuda_graph_decoder", [True, False]) +def test_rnnt_alignments(loop_labels: bool, use_cuda_graph_decoder: bool): + if not loop_labels and use_cuda_graph_decoder: + pytest.skip("Frame-Looping algorithm with CUDA graphs does not yet support alignments") # using greedy as baseline and comparing all other configurations to it ref_transcriptions = get_rnnt_alignments("greedy") - transcriptions = get_rnnt_alignments("greedy_batch", loop_labels=loop_labels) + transcriptions = get_rnnt_alignments( + "greedy_batch", loop_labels=loop_labels, use_cuda_graph_decoder=use_cuda_graph_decoder + ) # comparing that label sequence in alignments is exactly the same # we can't compare logits as well, because they are expected to be # slightly different in batched and single-sample mode diff --git a/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py b/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py index 538ff9d71cf1..31fe822573ce 100644 --- a/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py +++ b/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py @@ -11,19 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import copy import glob -import tempfile import jiwer import pytest import torch -from omegaconf import OmegaConf, open_dict +from omegaconf import open_dict from nemo.collections.asr.models import ASRModel from nemo.core.utils.cuda_python_utils import skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported +@pytest.fixture(scope="module") +def stt_en_fastconformer_transducer_xlarge(): + model_name = "stt_en_fastconformer_transducer_xlarge" + return ASRModel.from_pretrained(model_name, map_location="cpu") + + +@pytest.fixture(scope="module") +def stt_en_fastconformer_transducer_xxlarge(): + model_name = "stt_en_fastconformer_transducer_xxlarge" + return ASRModel.from_pretrained(model_name, map_location="cpu") + + +@pytest.fixture(scope="module") +def stt_en_fastconformer_transducer_large(): + model_name = "stt_en_fastconformer_transducer_large" + return ASRModel.from_pretrained(model_name, map_location="cpu") + + +@pytest.mark.with_downloads +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA decoder can run only on CUDA") @pytest.mark.parametrize( ("model_name", "batch_size", "enable_bfloat16"), [ @@ -42,28 +61,87 @@ ], ) @pytest.mark.parametrize("loop_labels", [False, True]) -def test_cuda_graph_rnnt_greedy_decoder(model_name, batch_size, enable_bfloat16, loop_labels: bool): - skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() +def test_cuda_graph_rnnt_greedy_decoder(model_name, batch_size, enable_bfloat16, loop_labels: bool, request): + if not loop_labels: + skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() + if enable_bfloat16 and not torch.cuda.is_bf16_supported(): + pytest.skip("bfloat16 is not supported") + + device = torch.device("cuda") + nemo_model = request.getfixturevalue(model_name).to(device) + decoding_config = copy.deepcopy(nemo_model.cfg.decoding) + + with open_dict(decoding_config): + decoding_config["greedy"]["max_symbols"] = 5 + decoding_config["greedy"]["loop_labels"] = loop_labels + decoding_config["greedy"]["use_cuda_graph_decoder"] = False + + nemo_model.change_decoding_strategy(decoding_config) + audio_filepaths = glob.glob("tests/.data/asr/test/an4/wav/*.wav") + + with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): + actual_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) + + decoding_config["greedy"]["use_cuda_graph_decoder"] = True + + nemo_model.change_decoding_strategy(decoding_config) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): + fast_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) - conf = ASRModel.from_pretrained(model_name, return_config=True) - with open_dict(conf): - conf["decoding"]["greedy"]["max_symbols"] = 5 - conf["decoding"]["greedy"]["loop_labels"] = loop_labels - conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = False + wer = jiwer.wer(actual_transcripts, fast_transcripts) - with tempfile.NamedTemporaryFile() as fp: - OmegaConf.save(config=conf, f=fp.name) - nemo_model = ASRModel.from_pretrained(model_name, override_config_path=fp.name, map_location="cuda") + assert wer <= 1e-3, "Cuda graph greedy decoder should match original decoder implementation." + for actual, fast in zip(actual_transcripts, fast_transcripts): + if actual != fast: + print("erroneous samples:") + print("Original transcript:", actual) + print("New transcript:", fast) + + +@pytest.mark.with_downloads +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA decoder can run only on CUDA") +@pytest.mark.parametrize("force_mode", ["no_graphs", "no_while_loops", "full_graph"]) +@pytest.mark.parametrize("enable_bfloat16", [False, True]) +def test_loop_labels_cuda_graph_rnnt_greedy_decoder_forced_mode( + stt_en_fastconformer_transducer_large, force_mode: str, enable_bfloat16: bool +): + """ + Testing Label-Looping algorithm with CUDA graphs in forced mode. + This test guarantees that we check that the fallback behavior is working. + NB: Since it is impossible to directly debug CUDA graphs, when making changes, + start testing and debugging the code with forced "no_graphs" mode. + """ + if enable_bfloat16 and not torch.cuda.is_bf16_supported(): + pytest.skip("bfloat16 is not supported") + + if force_mode == "full_graph": + skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() + + batch_size = 16 + device = torch.device("cuda") + nemo_model = stt_en_fastconformer_transducer_large.to(device) + decoding_config = copy.deepcopy(nemo_model.cfg.decoding) + + with open_dict(decoding_config): + decoding_config["greedy"]["max_symbols"] = 5 + decoding_config["greedy"]["loop_labels"] = True + decoding_config["greedy"]["use_cuda_graph_decoder"] = False + # test that alignments and confidence do not introduce failures + decoding_config["greedy"]["preserve_alignments"] = True + decoding_config["greedy"]["preserve_frame_confidence"] = True + + nemo_model.change_decoding_strategy(decoding_config) audio_filepaths = glob.glob("tests/.data/asr/test/an4/wav/*.wav") with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): actual_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) - with open_dict(conf): - conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = True - - nemo_model.change_decoding_strategy(conf["decoding"]) + # transcribe with use implementation with cuda graphs + decoding_config["greedy"]["use_cuda_graph_decoder"] = True + nemo_model.change_decoding_strategy(decoding_config) + nemo_model.decoding.decoding._decoding_computer.force_cuda_graphs_mode(mode=force_mode) with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): fast_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) @@ -79,27 +157,27 @@ def test_cuda_graph_rnnt_greedy_decoder(model_name, batch_size, enable_bfloat16, print("New transcript:", fast) +@pytest.mark.with_downloads +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="Test requires 2 GPUs") @pytest.mark.parametrize("loop_labels", [False, True]) -def test_change_devices(loop_labels: bool): - skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() - - if torch.cuda.device_count() < 2: - pytest.skip("Test requires more than 2 GPUs") +def test_change_devices(loop_labels: bool, stt_en_fastconformer_transducer_xlarge): + if not loop_labels: + skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() first_device = torch.device("cuda:0") second_device = torch.device("cuda:1") - model_name = "stt_en_fastconformer_transducer_xlarge" batch_size = 8 - conf = ASRModel.from_pretrained(model_name, return_config=True) - with open_dict(conf): - conf["decoding"]["greedy"]["max_symbols"] = 5 - conf["decoding"]["greedy"]["loop_labels"] = loop_labels - conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = True + nemo_model = stt_en_fastconformer_transducer_xlarge.to(second_device) + decoding_config = copy.deepcopy(nemo_model.cfg.decoding) + + with open_dict(decoding_config): + decoding_config["greedy"]["max_symbols"] = 5 + decoding_config["greedy"]["loop_labels"] = loop_labels + decoding_config["greedy"]["use_cuda_graph_decoder"] = True - nemo_model = ASRModel.from_pretrained(model_name, map_location=second_device) - nemo_model.change_decoding_strategy(conf["decoding"]) + nemo_model.change_decoding_strategy(decoding_config) # Test that the model can run successfully when it is first # initialized on second_device and then transferred to diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index a6e3714f20f5..c3b214751d04 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -432,9 +432,14 @@ def test_BeamRNNTInferConfig(self): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) - def test_greedy_decoding(self, greedy_class): + def test_greedy_decoding(self, greedy_class, loop_labels: Optional[bool]): token_list = [" ", "a", "b", "c"] vocab_size = len(token_list) @@ -454,7 +459,14 @@ def test_greedy_decoding(self, greedy_class): for joint_type in [RNNTJoint, HATJoint]: joint_net = joint_type(jointnet_cfg, vocab_size, vocabulary=token_list) - greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} + greedy = greedy_class( + decoder, + joint_net, + blank_index=len(token_list) - 1, + max_symbols_per_step=5, + **additional_decoding_kwargs, + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30) diff --git a/tests/collections/common/test_optional_cuda_graphs.py b/tests/collections/common/test_optional_cuda_graphs.py new file mode 100644 index 000000000000..7b1dda775863 --- /dev/null +++ b/tests/collections/common/test_optional_cuda_graphs.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from types import SimpleNamespace + +import torch.nn as nn + +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs + + +class MockClassWithCudaGraphs(WithOptionalCudaGraphs): + def __init__(self): + super().__init__() + self.cuda_graphs_used = True + + def disable_cuda_graphs(self): + self.cuda_graphs_used = False + + def maybe_enable_cuda_graphs(self): + self.cuda_graphs_used = True + + +class MockModuleWithCudaGraphs(MockClassWithCudaGraphs, nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 20) + + def forward(self, x): + return self.linear(x) + + +class MockModuleWithCudaGraphsByPath(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 20) + self.decoding = SimpleNamespace(decoding=MockClassWithCudaGraphs()) + + def forward(self, x): + return self.linear(x) + + +class TestWithOptionalCudaGraphs: + def test_module_toggle_cuda_graphs(self): + module_with_graphs = MockModuleWithCudaGraphs() + assert module_with_graphs.cuda_graphs_used + WithOptionalCudaGraphs.disable_cuda_graphs_recursive(module_with_graphs) + assert not module_with_graphs.cuda_graphs_used + WithOptionalCudaGraphs.enable_cuda_graphs_recursive(module_with_graphs) + assert module_with_graphs.cuda_graphs_used + + def test_module_toggle_cuda_graphs_by_path(self): + module_with_graphs_by_path = MockModuleWithCudaGraphsByPath() + assert module_with_graphs_by_path.decoding.decoding.cuda_graphs_used + WithOptionalCudaGraphs.disable_cuda_graphs_recursive( + module_with_graphs_by_path, attribute_path="decoding.decoding" + ) + assert not module_with_graphs_by_path.decoding.decoding.cuda_graphs_used + WithOptionalCudaGraphs.enable_cuda_graphs_recursive( + module_with_graphs_by_path, attribute_path="decoding.decoding" + ) + assert module_with_graphs_by_path.decoding.decoding.cuda_graphs_used From 10e15ed1ffdf409c1b130c024524d056ea13ffa7 Mon Sep 17 00:00:00 2001 From: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> Date: Fri, 3 May 2024 08:54:01 -0500 Subject: [PATCH 16/73] Alit/griffin (#9021) * add init griffin * remove unnecessary imports * add sft * add sft model init * add text gen starategy for Griffin no cache * test SFT * minor fix to config * fix logprob output issue * sft WS fixed * replace trainer in conversion script * Revert "Fix PTL2.2 saving multiple `*-last.ckpt` checkpoints in resumed training (#8480)" This reverts commit 11b7a733cbd4b8311eacba581323f88c7cd4bac4. * Revert "FSDP update to PTL 2.2 (#8658)" This reverts commit 355e36c344be55b2bf7b1fd55f5554a831e6fcd3. * init dist opt * add peft * fix generate script * convert to HF format * further cleanups * minor fix * minor fix * more refactoring * remove local path from config * undo unnecessary changes * remove pretraining * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix val param sync * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Addresing MR comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code ql fixed * more code ql * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add jenkins * remove jenkins for momentarily * add reqs for griffin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add req test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add reqs to nlp * add reqs to nlp * replace torch scan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * jit fusion for embedding decoder * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * jit fusion for embedding decoder * add fix to rglru * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Ali Taghibakhshi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper --- .../conf/megatron_griffin_config.yaml | 168 +++++++++ .../megatron_griffin_finetuning_config.yaml | 285 ++++++++++++++++ .../megatron_griffin_generate_config.yaml | 292 ++++++++++++++++ .../megatron_griffin_finetuning.py | 60 ++++ .../megatron_griffin_generate.py | 69 ++++ .../megatron/gpt_sft_dataset.py | 10 +- .../megatron/griffin/__init__.py | 13 + .../megatron/griffin/griffin_block.py | 75 ++++ .../megatron/griffin/griffin_layer_spec.py | 81 +++++ .../megatron/griffin/griffin_model.py | 156 +++++++++ .../megatron/griffin/recurrent_layer.py | 106 ++++++ .../megatron/griffin/recurrent_module.py | 321 ++++++++++++++++++ .../megatron_gpt_sft_model.py | 1 + .../megatron_griffin_model.py | 96 ++++++ .../megatron_griffin_sft_model.py | 55 +++ .../common/text_generation_strategy.py | 75 ++++ requirements/requirements_nlp.txt | 2 + .../convert_griffin_hf_to_nemo.py | 174 ++++++++++ .../convert_griffin_nemo_to_hf.py | 147 ++++++++ 19 files changed, 2185 insertions(+), 1 deletion(-) create mode 100644 examples/nlp/language_modeling/conf/megatron_griffin_config.yaml create mode 100644 examples/nlp/language_modeling/conf/megatron_griffin_finetuning_config.yaml create mode 100644 examples/nlp/language_modeling/conf/megatron_griffin_generate_config.yaml create mode 100644 examples/nlp/language_modeling/megatron_griffin_finetuning.py create mode 100644 examples/nlp/language_modeling/megatron_griffin_generate.py create mode 100755 nemo/collections/nlp/models/language_modeling/megatron/griffin/__init__.py create mode 100755 nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_block.py create mode 100755 nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_layer_spec.py create mode 100755 nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_model.py create mode 100755 nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_layer.py create mode 100755 nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_module.py create mode 100644 nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py create mode 100644 nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py create mode 100644 scripts/checkpoint_converters/convert_griffin_hf_to_nemo.py create mode 100644 scripts/checkpoint_converters/convert_griffin_nemo_to_hf.py diff --git a/examples/nlp/language_modeling/conf/megatron_griffin_config.yaml b/examples/nlp/language_modeling/conf/megatron_griffin_config.yaml new file mode 100644 index 000000000000..ea23cf630f8b --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_griffin_config.yaml @@ -0,0 +1,168 @@ +name: megatron_griffin +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch. + max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + benchmark: False + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_griffin + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + filename: 'megatron_griffin--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + + +model: + restore_from_path: null + # model parallelism + micro_batch_size: 2 + global_batch_size: 2 + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + vocab_size: 256000 + # model architecture + encoder_seq_length: 512 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: 'rope' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental. + logits_soft_cap: 30.0 + num_layers: 26 + gated_linear_unit: True + window_size: [1024, 0] + num_query_groups: 1 + attention_dropout: 0.0 + hidden_dropout: 0.0 + hidden_size: 2560 + bias_activation_fusion: True + ffn_hidden_size: 7680 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 10 + transformer_block_type: pre_ln + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: RMSNorm + layernorm_epsilon: 1e-6 + rotary_interleaved: False + layernorm_zero_centered_gamma: True + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + megatron_legacy: False + + tokenizer: + library: 'huggingface' + type: 'google/recurrentgemma-2b' + model: null + vocab_file: null + merge_file: null + sentencepiece_legacy: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: False + + # miscellaneous + seed: 1234 + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + ## Activation Checkpointing + # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. + # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + # 'full' will checkpoint the entire transformer layer. + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null + # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory. + # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage. + num_micro_batches_with_partial_activation_checkpoints: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed + # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is + # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint + # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'. + # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage. + activations_checkpoint_layers_per_pipeline: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later + # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than + # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage + # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints', + # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path. + sequence_parallel: False + + data: + # Path to data must be specified by the user. + # can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", + # Or see example below: + # data_prefix: + # - .5 + # - /raid/data/pile/my-gpt3_00_text_document + # - .5 + # - /raid/data/pile/my-gpt3_01_text_document + data_prefix: [1.0, /path/to/data] + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_impl: mmap + splits_string: 900,50,50 + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + dataloader_type: single # cyclic, LDDL + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + masked_lm_prob: 0.15 # Probability of replacing a token with mask. + short_seq_prob: 0.1 # Probability of producing a short sequence. + ceil_to_power_2: True + + optim: + name: fused_adam + lr: 2e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 500 + constant_steps: 50000 + min_lr: 2e-5 diff --git a/examples/nlp/language_modeling/conf/megatron_griffin_finetuning_config.yaml b/examples/nlp/language_modeling/conf/megatron_griffin_finetuning_config.yaml new file mode 100644 index 000000000000..64d1b67bc148 --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_griffin_finetuning_config.yaml @@ -0,0 +1,285 @@ +name: megatron_griffin +restore_from_path: ${model.restore_from_path} # used when starting from a .nemo file + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 10000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 1 # frequency with which training steps are logged + val_check_interval: 200 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + limit_val_batches: 1024 + limit_test_batches: 500 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: True + wandb_logger_kwargs: + project: griffin + name: sft-test + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: True + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + restore_from_path: + # model parallelism + micro_batch_size: 2 + global_batch_size: 2 + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + vocab_size: 256000 + apply_rope_fusion: True + # model architecture + encoder_seq_length: 512 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: 'rope' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental. + num_layers: 26 + gated_linear_unit: True + window_size: [1024, 0] + num_query_groups: 1 + attention_dropout: 0.0 + hidden_dropout: 0.0 + hidden_size: 2560 + bias_activation_fusion: True + ffn_hidden_size: 7680 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 10 + transformer_block_type: pre_ln + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: RMSNorm + layernorm_epsilon: 1e-6 + rotary_interleaved: False + layernorm_zero_centered_gamma: True + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + megatron_legacy: False + activation: 'fast-geglu' + + tokenizer: + library: 'huggingface' + type: 'google/recurrentgemma-2b' + model: null + vocab_file: null + merge_file: null + sentencepiece_legacy: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: False + + # miscellaneous + seed: 1234 + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + ## Activation Checkpointing + # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. + # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + # 'full' will checkpoint the entire transformer layer. + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null + # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory. + # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage. + num_micro_batches_with_partial_activation_checkpoints: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed + # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is + # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint + # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'. + # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage. + activations_checkpoint_layers_per_pipeline: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later + # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than + # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage + # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints', + # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path. + sequence_parallel: False + + peft: + peft_scheme: "lora" # can be either adapter,ia3, lora, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['all'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + + data: + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_names: null # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + memmap_workers: 2 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + concat_sampling_probabilities: [1.0] # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + label_key: 'output' + add_eos: True + add_sep: False + add_bos: True + truncation_field: "input" # # Can be multiple keys separated with ',' Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + validation_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + test_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + optim: + name: distributed_fused_adam + lr: 2e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 500 + constant_steps: 50000 + min_lr: 2e-5 diff --git a/examples/nlp/language_modeling/conf/megatron_griffin_generate_config.yaml b/examples/nlp/language_modeling/conf/megatron_griffin_generate_config.yaml new file mode 100644 index 000000000000..4b3c14c846d1 --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_griffin_generate_config.yaml @@ -0,0 +1,292 @@ +name: megatron_griffin +restore_from_path: ${model.restore_from_path} # used when starting from a .nemo file + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 10000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 1 # frequency with which training steps are logged + val_check_interval: 200 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + limit_val_batches: 1024 + limit_test_batches: 500 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: True + wandb_logger_kwargs: + project: griffin + name: sft-test + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: True + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + restore_from_path: null + # model parallelism + micro_batch_size: 2 + global_batch_size: 2 + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + vocab_size: 256000 + apply_rope_fusion: True + # model architecture + encoder_seq_length: 512 + logits_soft_cap: 30.0 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: 'rope' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental. + num_layers: 26 + gated_linear_unit: True + window_size: [1024, 0] + num_query_groups: 1 + attention_dropout: 0.0 + hidden_dropout: 0.0 + hidden_size: 2560 + bias_activation_fusion: True + ffn_hidden_size: 7680 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 10 + transformer_block_type: pre_ln + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: RMSNorm + layernorm_epsilon: 1e-6 + rotary_interleaved: False + layernorm_zero_centered_gamma: True + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + megatron_legacy: False + activation: 'fast-geglu' + + answer_only_loss: True + + + tokenizer: + library: 'huggingface' + type: 'google/recurrentgemma-2b' + model: null + vocab_file: null + merge_file: null + sentencepiece_legacy: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: False + + # miscellaneous + seed: 1234 + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + ## Activation Checkpointing + # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. + # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + # 'full' will checkpoint the entire transformer layer. + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null + # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory. + # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage. + num_micro_batches_with_partial_activation_checkpoints: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed + # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is + # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint + # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'. + # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage. + activations_checkpoint_layers_per_pipeline: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later + # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than + # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage + # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints', + # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path. + sequence_parallel: False + + peft: + peft_scheme: "lora" # can be either adapter,ia3, lora, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['all'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + data: + test_ds: + file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: ??? # Names of the corresponding datasets used to log metrics. + global_batch_size: 1 + micro_batch_size: 1 + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: 'input' + label_key: 'output' + add_eos: True + add_sep: False + add_bos: True + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "input" # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "{input} {output}" + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + +inference: + greedy: True # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. + compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + outfile_path: output.txt + compute_attention_mask: True + +# server-related configs +server: False # whether launch the API server +port: 5555 # the port number for the inference server +web_server: False # whether launch the web inference server +share: True # whether create a public URL +username: test # user name for web client +password: test2 # password for web client +web_port: 9889 # the port number of the web server 1058 +chat: False # use the chat interface +chatbot_config: + value: False # whether to inject the value attributes + attributes: + - name: Quality + min: 0 + max: 4 + key: quality + type: int + default: 4 + - name: Toxicity + min: 0 + max: 4 + key: toxcity + type: int + default: 0 + - name: Humor + min: 0 + max: 4 + key: humor + type: int + default: 0 + - name: Creativity + min: 0 + max: 4 + key: creativity + type: int + default: 0 + - name: Violence + min: 0 + max: 4 + key: violence + type: int + default: 0 + - name: Helpfulness + min: 0 + max: 4 + key: helpfulness + type: int + default: 4 + - name: Not_Appropriate + min: 0 + max: 4 + key: not_appropriate + type: int + default: 0 + - name: Language + choices: ['ar', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'eo', 'es', 'eu', 'fa', 'fi', 'fr', 'gl', 'he', 'hu', 'id', 'it', 'ja', 'ko', 'nb', 'nl', 'pl', 'pt', 'ro', 'ru', 'sk', 'sv', 'th', 'tr', 'uk', 'vi', 'zh'] + key: lang + type: list + default: en + + user: User + assistant: Assistant + system: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" \ No newline at end of file diff --git a/examples/nlp/language_modeling/megatron_griffin_finetuning.py b/examples/nlp/language_modeling/megatron_griffin_finetuning.py new file mode 100644 index 000000000000..c5ae513d5874 --- /dev/null +++ b/examples/nlp/language_modeling/megatron_griffin_finetuning.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf + +from nemo.collections.nlp.models.language_modeling.megatron_griffin_sft_model import MegatronGriffinSFTModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="megatron_griffin_finetuning_config") +def main(cfg) -> None: + + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + precision = cfg.trainer.precision + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + # Restore the precision value after Trainer is built. + cfg.trainer.precision = precision + exp_manager(trainer, cfg.exp_manager) + + model_cfg = MegatronGriffinSFTModel.merge_cfg_with(cfg.model.restore_from_path, cfg) + model = MegatronGriffinSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + + peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] + + if cfg.model.peft.restore_from_path is not None: + # initialize peft weights from a checkpoint instead of randomly + # This is not the same as resume training because optimizer states are not restored. + logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path) + model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg)) + elif peft_cfg_cls is not None: + logging.info("Adding adapter weights to the model for PEFT") + model.add_adapter(peft_cfg_cls(model_cfg)) + else: + logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}") + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/language_modeling/megatron_griffin_generate.py b/examples/nlp/language_modeling/megatron_griffin_generate.py new file mode 100644 index 000000000000..c8e36668fced --- /dev/null +++ b/examples/nlp/language_modeling/megatron_griffin_generate.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf +from nemo.collections.nlp.models.language_modeling.megatron_griffin_sft_model import MegatronGriffinSFTModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.model_utils import inject_model_parallel_rank + + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="megatron_griffin_generate_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + + if cfg.model.peft.restore_from_path: + model_cfg = MegatronGriffinSFTModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg) + else: + model_cfg = MegatronGriffinSFTModel.merge_inference_cfg(cfg.model.restore_from_path, cfg) + + model = MegatronGriffinSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + + if cfg.model.peft.restore_from_path: + model.load_adapters(cfg.model.peft.restore_from_path) + elif cfg.model.peft.restore_from_ckpt.checkpoint_dir and cfg.model.peft.restore_from_ckpt.checkpoint_name: + peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] + checkpoint_path = os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = inject_model_parallel_rank( + os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + ) + model.load_adapters(checkpoint_path, peft_cfgs=peft_cfg_cls(model_cfg)) + else: + raise NotImplementedError("distributed checkpointing of PEFT weights is not supported") + + model.freeze() + logging.info(f"Freezing parameters for PEFT eval:\n{model.summarize()}") + + trainer.test(model) + + +if __name__ == "__main__": + main() diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py index 501c766374e1..6354387c18e7 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import re from typing import List, Mapping, Optional @@ -60,6 +61,7 @@ def __init__( special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictory of {token_type: token} is_test: bool = False, output_original_text: bool = False, + ceil_to_power_2: bool = False, ): """ file_path: Path to a JSONL GPT supervised fine-tuning dataset. Data is formatted as multiple JSON lines with each line formatted as follows. {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} @@ -109,6 +111,8 @@ def __init__( self.truncation_method = truncation_method self.is_test = is_test self.output_original_text = output_original_text + self.ceil_to_power_2 = ceil_to_power_2 + if special_tokens is None: self.special_tokens = { "system_turn_start": "", @@ -406,7 +410,11 @@ def _maybe_cast_to_list(self, x): return x def _ceil_to_nearest(self, n, m): - return (n + m - 1) // m * m + if self.ceil_to_power_2: + # Reccurent Gemma (AKA Griffin) requires seq length to be a power of 2 for parallel scan + return 2 ** math.ceil(math.log2(n)) + else: + return (n + m - 1) // m * m def _collate_item(self, item, max_length, pad_id): item = self._maybe_cast_to_list(item) diff --git a/nemo/collections/nlp/models/language_modeling/megatron/griffin/__init__.py b/nemo/collections/nlp/models/language_modeling/megatron/griffin/__init__.py new file mode 100755 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/griffin/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_block.py b/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_block.py new file mode 100755 index 000000000000..3fc26a51f3c1 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_block.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.transformer.custom_layers.transformer_engine import TENorm +from megatron.core.transformer.spec_utils import build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import nn + +from nemo.collections.nlp.models.language_modeling.megatron.griffin.griffin_layer_spec import ( + griffin_mqa_layer_with_transformer_engine_spec, + griffin_recurrent_layer_with_transformer_engine_spec, +) + + +def get_griffin_layers(num_layers): + dict_spec = { + "Recurrent_Layer": griffin_recurrent_layer_with_transformer_engine_spec, + "Attention_Layer": griffin_mqa_layer_with_transformer_engine_spec, + } + + griffin_layers = [] + for i in range(num_layers): + if i % 3 == 2: + griffin_layers.append(dict_spec["Attention_Layer"]) + else: + griffin_layers.append(dict_spec["Recurrent_Layer"]) + + return griffin_layers + + +def create_block( + config, layer_spec, layer_idx, +): + block = build_module(layer_spec, config,) + block.layer_number = layer_idx + 1 + return block + + +class GriffinStack(LanguageModule): + def __init__( + self, config: TransformerConfig, + ): + + super().__init__(config) + self.config = config + self.griffin_layers = get_griffin_layers(self.config.num_layers) + + self.layers = nn.ModuleList( + [create_block(self.config, layer_spec, layer_idx=i,) for i, layer_spec in enumerate(self.griffin_layers)] + ) + self.final_layernorm = TENorm( + config=self.config, hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, + ) + + def forward(self, hidden_states, attention_mask, rotary_pos_emb): + + for layer in self.layers: + + hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb) + + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states diff --git a/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_layer_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_layer_spec.py new file mode 100755 index 000000000000..a504898e9d64 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_layer_spec.py @@ -0,0 +1,81 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.custom_layers.transformer_engine import ( + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TERowParallelLinear, +) +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +from nemo.collections.nlp.models.language_modeling.megatron.griffin.recurrent_layer import ( + RecurrentBlock, + RecurrentBlockSubmodules, +) +from nemo.collections.nlp.models.language_modeling.megatron.griffin.recurrent_module import ( + RGLRU, + Conv1D, + RecurrentLayer, + RecurrentLayerSubmodules, +) + +griffin_mqa_layer_with_transformer_engine_spec = ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules(linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,), + ), + mlp_bda=get_bias_dropout_add, + ), +) + +griffin_recurrent_layer_with_transformer_engine_spec = ModuleSpec( + module=RecurrentBlock, + submodules=RecurrentBlockSubmodules( + recurrent_layer=ModuleSpec( + module=RecurrentLayer, + submodules=RecurrentLayerSubmodules( + linear_in=TELayerNormColumnParallelLinear, + linear_out=TERowParallelLinear, + conv_1d=Conv1D, + rg_lru=RGLRU, + ), + ), + recurrent_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules(linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,), + ), + mlp_bda=get_bias_dropout_add, + ), +) diff --git a/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_model.py b/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_model.py new file mode 100755 index 000000000000..9f00fb9dd156 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_model.py @@ -0,0 +1,156 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from megatron.core.jit import jit_fuser +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import Tensor, nn + +from nemo.collections.nlp.models.language_modeling.megatron.griffin.griffin_block import GriffinStack + + +class GriffinModel(LanguageModule): + def __init__( + self, + config: TransformerConfig, + vocab_size: int = 256000, + logits_soft_cap: float = 30.0, + position_embedding_type: str = 'rope', + max_sequence_length: int = 1024, + rotary_percent: float = 0.5, + rotary_base: int = 10000, + pre_process=True, + ): + + super().__init__(config) + self.config = config + self.vocab_size = vocab_size + self.logits_soft_cap = logits_soft_cap + self.position_embedding_type = position_embedding_type + self.pre_process = pre_process + self.post_process = False + self.share_embeddings_and_output_weights = True + + if pre_process: + self.embedding = LanguageModelEmbedding( + config, + vocab_size=self.vocab_size, + max_sequence_length=max_sequence_length, + position_embedding_type=None, + ) + + if self.position_embedding_type == 'rope': + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=config.rotary_interleaved, + seq_len_interpolation_factor=None, + rotary_base=rotary_base, + ) + + self.decoder = GriffinStack(self.config) + + def shared_embedding_or_output_weight(self) -> Tensor: + """Gets the emedding weight or output logit weights when share embedding and output weights set to True. + + Returns: + Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight + """ + if self.pre_process: + return self.embedding.word_embeddings.weight + elif self.post_process: + return self.output_layer.weight + return None + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return { + i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + for i, layer in enumerate(self.layers) + } + + def set_input_tensor(self, input_tensor: Tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def griffin_position_ids(self, token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + def embedding_forward(self, input_ids): + + position_ids = self.griffin_position_ids(input_ids) + embeddings = self.embedding(input_ids, position_ids) + embeddings = embeddings * torch.tensor(math.sqrt(self.config.hidden_size)).type_as(embeddings) + + return embeddings + + @jit_fuser + def _embedding_decode_(self, logits, transpose): + logits = nn.functional.tanh(logits / self.logits_soft_cap) * self.logits_soft_cap + if transpose: + logits = logits.transpose(0, 1) + return logits.contiguous() + + def embedding_decode(self, x, transpose): + x = x.permute(1, 0, 2) + logits = x @ self.embedding.word_embeddings.state_dict()['weight'].T + logits = self._embedding_decode_(logits, transpose) + + return logits + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor = None, + attention_mask: Tensor = None, + labels: Tensor = None, + **extra_arg + ): + if input_ids is None: + input_ids = self.input_tensor + + hidden_states = self.embedding_forward(input_ids) + + rotary_pos_emb = None + self.decoder.input_tensor = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(None, self.decoder, hidden_states, self.config) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + hidden_states = self.decoder(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb) + + logits = self.embedding_decode(hidden_states, labels is not None) + + if labels is None: + # [b s h] + return logits + + loss = self.compute_language_model_loss(labels, logits) + + return loss diff --git a/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_layer.py b/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_layer.py new file mode 100755 index 000000000000..8263f54889a0 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_layer.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Union + +from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import make_viewless_tensor +from torch import Tensor + + +@dataclass +class RecurrentBlockSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + recurrent_layer: Union[ModuleSpec, type] = IdentityOp + recurrent_bda: Union[ModuleSpec, type] = IdentityFuncOp + + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp + + +class RecurrentBlock(MegatronModule): + def __init__( + self, + config: TransformerConfig, + submodules: RecurrentBlockSubmodules, + layer_idx=None, + residual_in_fp32=False, + **kwargs, + ): + """ + Top level Mamba Layer + """ + super().__init__(config) + self.config = config + self.residual_in_fp32 = residual_in_fp32 + self.hidden_dropout = config.hidden_dropout + + self.input_layernorm = build_module(submodules.input_layernorm, dim=self.config.hidden_size) + + self.recurrent_layer = build_module( + submodules.recurrent_layer, + self.config, + width=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + lru_width=self.config.hidden_size, + conv1d_temporal_width=4, + final_w_init_variance_scale=1.0, + ) + + self.recurrent_bda = build_module(submodules.recurrent_bda) + + self.pre_mlp_layernorm = build_module(submodules.pre_mlp_layernorm, dim=self.config.hidden_size) + + self.mlp = build_module(submodules.mlp, config=self.config) + + self.mlp_bda = build_module(submodules.mlp_bda) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, inference_params=None, **kwargs): + + residual = hidden_states + + # Optional Input Layer norm + input_layernorm_output = self.input_layernorm(hidden_states) + + # Reccurent block. + recurrent_output_with_bias = self.recurrent_layer(input_layernorm_output) + + hidden_states = self.recurrent_bda(self.training, self.config.bias_dropout_fusion)( + recurrent_output_with_bias, residual, self.hidden_dropout + ) + + # Residual connection. + residual = hidden_states + + # Optional Layer norm post the cross-attention. + pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) + + # MLP. + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + + hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)( + mlp_output_with_bias, residual, self.hidden_dropout + ) + + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + + return output, None + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_module.py b/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_module.py new file mode 100755 index 000000000000..6cd9eeaadc63 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_module.py @@ -0,0 +1,321 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Union + +import einops +import torch +from accelerated_scan.ref import scan +from causal_conv1d import causal_conv1d_fn +from einops import rearrange +from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import nn + + +# Class copied from https://github.com/google-deepmind/recurrentgemma +class BlockDiagonalLinear(nn.Module): + """Block-diagonal linear layer.""" + + def __init__( + self, width: int, num_blocks: int, w_init_variance_scale: float = 1.0, + ): + """Initializes the BlockDiagonalLinear. + + Args: + width: The number of dimensions of the input and output. + num_blocks: The number of diagonal blocks in the layer. + w_init_variance_scale: A parameters that scales the variance of the + initialization of the weights. + """ + super().__init__() + self.width = width + self.num_blocks = num_blocks + self.w_init_variance_scale = w_init_variance_scale + self.block_width = self.width // self.num_blocks + + # Parameters. + self.w = nn.Parameter(torch.zeros([self.num_blocks, self.block_width, self.block_width])) + self.b = nn.Parameter(torch.zeros([self.num_blocks, self.block_width])) + + # Initialization. + self.w_init_(self.w) + + def w_init_(self, w: torch.Tensor) -> None: + """Initializes the weight `w` of the layer.""" + std = math.sqrt(self.w_init_variance_scale / self.block_width) + torch.nn.init.normal_(w, mean=0.0, std=std) + + def forward(self, x): + """Calls the BlockDiagonalLinear.""" + # Split x to blocks. + x = einops.rearrange(x, "... (h i) -> ... h i", h=self.num_blocks) + + # Linear layer over each block + bias. + y = torch.einsum("... h i, h i j -> ... h j", x, self.w) + self.b + + # Flatten the output. + return einops.rearrange(y, "... h j -> ... (h j)", h=self.num_blocks) + + +# Class copied from https://github.com/google-deepmind/recurrentgemma + + +def rnn_scan( + x, a, reset, h0, +): + """Runs the recurrence of a linear RNN. + + Args: + x: The input sequence. + a: The diagonal of the recurrence matrix `A`. + reset: Indicator of document boundaries, e.g. when to reset the hidden + state of the RNN. + h0: The initial hidden state. + + Returns: + The output of the linear recurrence. + """ + + assert x.ndim == 3 + assert a.shape == x.shape[-a.ndim :] + assert a.dtype == x.dtype + assert type(a) is type(x) + + # Multiply `a` by the reset. + a = a * (1 - reset)[..., None] + + if x.shape[1] == 1: + # Using scan in sampling mode. + y = a * h0[:, None] + x + else: + # Using scan in linear mode. + x = x.permute(0, 2, 1) + a = a.permute(0, 2, 1) + x = x.contiguous() + a = a.contiguous() + y = scan(a.float(), x.float()).type_as(x) + y = y.permute(0, 2, 1) + return y, None + + +# Class copied from https://github.com/google-deepmind/recurrentgemma + + +def rnn_param_init(*, width: int, min_rad: float, max_rad: float, transform: str = "softplus",) -> torch.Tensor: + """Initializes the `A` parameter of the RG-LRU uniformly on a ring.""" + unif = torch.rand(width) + # Proportional to area in a ring. + a_real = 0.5 * torch.log(unif * (max_rad ** 2 - min_rad ** 2) + min_rad ** 2 + 1e-8) + + if transform == "softplus": + # Inverse transform. + return torch.log(torch.exp(-a_real) - 1.0) + else: + raise NotImplementedError() + + +# Class copied from https://github.com/google-deepmind/recurrentgemma + + +class RGLRU(nn.Module): + """A Real-Gated Linear Recurrent Unit (RG-LRU) layer.""" + + def __init__( + self, width: int, num_heads: int, w_init_variance_scale: float = 1.0, + ): + """Initializes the RG-LRU. + + Args: + width: The number of dimensions of the input and output. + num_heads: The number of diagonal blocks in the input and A gate layers. + w_init_variance_scale: Initialization parameter for the + BlockDiagonalLinear layers of the gates. See the `BlockDiagonalLinear` + layer for details. + """ + super().__init__() + self.width = width + self.num_heads = num_heads + self.w_init_variance_scale = w_init_variance_scale + + # Parameters and layers. + self.a_param = nn.Parameter(self.a_param_init) + self.input_gate = BlockDiagonalLinear( + width=self.width, num_blocks=self.num_heads, w_init_variance_scale=w_init_variance_scale, + ) + self.a_gate = BlockDiagonalLinear( + width=self.width, num_blocks=self.num_heads, w_init_variance_scale=self.w_init_variance_scale + ) + + @property + def a_param_init(self) -> torch.Tensor: + """Initializes the `A` parameter of the RG-LRU.""" + return rnn_param_init(width=self.width, min_rad=0.9, max_rad=0.999) + + def __call__( + self, x, segment_pos, prev_h, + ): + """Calls the RG-LRU. + + Args: + x: Sequence of input activations. + segment_pos: Position of each token in the sequence. + prev_h: The previous hidden state of the RG-LRU. + + Returns: + Output of the block together with the updated hidden state. + """ + for param in self.parameters(): + param.data_ptr() + + bs, l, d = x.shape + assert segment_pos.shape == (bs, l) + reset = (segment_pos == 0).type(torch.int32) + prev_h = torch.zeros(size=(bs, d)) if prev_h is None else prev_h + prev_h = prev_h.cuda() + # Gates for x and a. + gate_x = torch.sigmoid(self.input_gate(x)) + gate_a = torch.sigmoid(self.a_gate(x)) + + # Compute the parameter `A` of the recurrence. + log_a = -8.0 * gate_a * nn.functional.softplus(self.a_param) + a = torch.exp(log_a) + + # Gate the input. + gated_x = x * gate_x + + # Apply gamma normalization to the input. + multiplier = torch.sqrt((1 - torch.exp(2 * log_a)) + 1e-6) + multiplier = reset[..., None] + (1 - reset)[..., None] * multiplier + normalized_x = gated_x * multiplier.type(x.dtype) + + y, last_h = rnn_scan(x=normalized_x, a=a, reset=reset, h0=prev_h,) + + return y, last_h + + +class Conv1D(MegatronModule): + def __init__(self, config, width, temporal_width): + super().__init__(config=config) + self.config = config + self.width = width + self.temporal_width = temporal_width + self.conv_1d = nn.Conv1d( + in_channels=width, + out_channels=width, + bias=True, + kernel_size=temporal_width, + groups=width, + padding=temporal_width - 1, + ) + + def forward( + self, x, segment_pos=None, prev_x=None, + ): + x = x.permute(0, 2, 1) + output = causal_conv1d_fn( + x=x, weight=rearrange(self.conv_1d.weight, "d 1 w -> d w"), bias=self.conv_1d.bias, activation=None, + ).permute(0, 2, 1) + return output, None + + +@dataclass +class RecurrentLayerSubmodules: + linear_in: Union[ModuleSpec, type] = IdentityOp + linear_out: Union[ModuleSpec, type] = IdentityOp + conv_1d: Union[ModuleSpec, type] = IdentityOp + rg_lru: Union[ModuleSpec, type] = IdentityOp + + +def gelu(x: torch.Tensor) -> torch.Tensor: + """Returns the GELU activation function with the same approximation as JAX.""" + return nn.functional.gelu(x, approximate="tanh") + + +class RecurrentLayer(MegatronModule): + def __init__( + self, + config: TransformerConfig, + submodules: RecurrentLayerSubmodules, + layer_idx=None, + residual_in_fp32=False, + **kwargs, + ): + """ + Top level Mamba Layer + """ + super().__init__(config) + self.config = config + self.residual_in_fp32 = residual_in_fp32 + + self.linear_in = build_module( + submodules.linear_in, + self.config.hidden_size, + self.config.hidden_size * 2, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear, + skip_bias_add=True, + is_expert=False, + ) + + self.linear_out = build_module( + submodules.linear_out, + self.config.hidden_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.init_method, + bias=self.config.add_bias_linear, + skip_bias_add=True, + is_expert=False, + input_is_parallel=True, + ) + + self.conv_1d = build_module( + submodules.conv_1d, config=self.config, width=self.config.hidden_size, temporal_width=4 + ) + + self.rg_lru = build_module( + submodules.rg_lru, width=self.config.hidden_size, num_heads=self.config.num_attention_heads + ) + + def forward(self, hidden_states, attention_mask=None, rotary_pos_emb=None): + + segment_pos = torch.arange(hidden_states.shape[0]).unsqueeze(0).repeat(hidden_states.shape[1], 1).cuda() + in_intermidiate_parallel, in_bias_parallel = self.linear_in(hidden_states) + + x_bias_parallel, y_bias_parallel = in_bias_parallel.chunk(2, dim=-1) + x_intermidiate_parallel, y_intermidiate_parallel = in_intermidiate_parallel.chunk(2, dim=-1) + + y = bias_gelu_impl(y_intermidiate_parallel, y_bias_parallel) + + x = x_intermidiate_parallel + x_bias_parallel + x = x.permute(1, 0, 2) + + x, _ = self.conv_1d(x=x, segment_pos=segment_pos, prev_x=None) + + x, _ = self.rg_lru(x=x, segment_pos=segment_pos, prev_h=None,) + + x = x.permute(1, 0, 2) + + x = x * y + x_intermidiate_parallel, x_bias_parallel = self.linear_out(x) + + return x_intermidiate_parallel, x_bias_parallel diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index 892a87189880..32b22df22d2c 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -290,6 +290,7 @@ def _build_dataset(self, data_cfg, is_train=True): pad_to_max_length=data_cfg.get('pad_to_max_length', False), index_mapping_dir=data_cfg.get('index_mapping_dir', None), prompt_template=data_cfg.get('prompt_template', None), + ceil_to_power_2=data_cfg.get('ceil_to_power_2', False), virtual_tokens=self.virtual_tokens, tokens_to_generate=data_cfg.get( 'tokens_to_generate', 0 diff --git a/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py b/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py new file mode 100644 index 000000000000..20ad376b8f98 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron.griffin.griffin_model import GriffinModel +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults + +try: + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + TransformerConfig = ApexGuardDefaults + HAVE_MEGATRON_CORE = False + + +class MegatronGriffinModel(MegatronGPTModel): + """ + Megatron Griffin pretraining. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + + # build the transformer config + # TODO: add type hint once pip package is out + + self.vocab_size = cfg.get('vocab_size', 256000) + self.cfg = cfg + super().__init__(cfg=cfg, trainer=trainer) + self.mcore_gpt = True + + def model_provider_func(self, pre_process, post_process): + model = GriffinModel( + config=self.transformer_config, + max_sequence_length=self.cfg.get('encoder_seq_length', 512), + vocab_size=self.cfg.get('vocab_size', 256000), + position_embedding_type=self.cfg.get('position_embedding_type', 'rope'), + logits_soft_cap=self.cfg.get('logits_soft_cap', 30.0), + rotary_percent=self.cfg.get('rotary_percentage', 0.5), + rotary_base=self.cfg.get('rotary_base', 10000), + ) + + return model + + def forward(self, input_ids, position_ids=None, attention_mask=None, labels=None): + + output_tensor = self.model( + input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, labels=labels + ) + return output_tensor + + def build_transformer_config(self): + transformer_config = super().build_transformer_config() + transformer_config.gated_linear_unit = self.cfg.get('gated_linear_unit', True) + transformer_config.layernorm_zero_centered_gamma = self.cfg.get('layernorm_zero_centered_gamma', True) + + return transformer_config + + def on_validation_epoch_end(self): + + averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda() + return averaged_loss + + def sharded_state_dict(self, prefix: str = ''): + return None + + def _reset_activation_checkpointing_args(self): + return + + def _restore_activation_checkpointing_args(self): + return + + def _reset_sequence_parallelism_args(self): + return + + def _restore_sequence_parallelism_args(self): + return diff --git a/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py new file mode 100644 index 000000000000..c53d231b2719 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from omegaconf import DictConfig +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel +from nemo.collections.nlp.models.language_modeling.megatron_griffin_model import MegatronGriffinModel + +try: + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +__all__ = ['MegatronGriffinSFTModel'] + + +class MegatronGriffinSFTModel(MegatronGPTSFTModel, MegatronGriffinModel): + """ + Megatron Griffin Supervised Fine-Tuning + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + if not HAVE_APEX: + raise ImportError( + "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + + super().__init__(cfg, trainer=trainer) + self.mcore_gpt = True + self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False) + + def _reset_activation_checkpointing_args(self): + pass + + def on_validation_model_zero_grad(self) -> None: + """ + Skip gradient zeroing at the beginning of validation routine. + This is needed when overlapping the AllGather of the updated parameters with the following valdation step. + """ + if not self.validation_param_sync_overlap: + MegatronBaseModel.on_validation_model_zero_grad(self) diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index e29bb3423c4a..c6e96e94e6ff 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -330,6 +330,78 @@ def prepare_batch_at_step( return batch, tensor_shape +class GriffinModelTextGenerationStrategy(TextGenerationStrategy): + def __init__(self, model): + super().__init__(model) + self.forward_model = self.model.model + + def clip_max_len(self, maxlen: int) -> int: + """ clip the max len based on the LM model max sequence length""" + + # for positional embedding types that allow length extrapolation, don't clip the max length + if self.model.cfg.get("position_embedding_type", "learned_absolute") == "learned_absolute": + if maxlen > self.model.cfg.encoder_seq_length + 1: + maxlen = self.model.cfg.encoder_seq_length + 1 + return maxlen + + def init_batch(self, context_tokens: torch.Tensor, context_length: int, compute_attention_mask: bool): + """initialize the batch data before the inference steps.""" + # Move to GPU. + tokenizer = self.model.tokenizer + tokens = context_tokens.contiguous().cuda() + # Get the attention mask and postition ids. + self.attention_mask, _, self.position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eos_id, + self.model.cfg.get('reset_position_ids', False), + self.model.cfg.get('reset_attention_mask', False), + self.model.cfg.get('eod_mask_loss', False), + compute_attention_mask=compute_attention_mask, + ) + self.attention_mask = None + + def prepare_batch_at_step( + self, + tokens: torch.Tensor, + maxlen: int, + micro_batch_size: int, + step: int, + context_length: int, + compute_attention_mask: bool = False, + ) -> Tuple[List[torch.Tensor], List[int]]: + """ + generate the batch used in inference for each of the steps + """ + # types2use = None + # Allocate memory for the entire context. + + tokens2use = tokens + + """Prepare batch for each of the inference steps""" + attention_mask_repeat = None + + batch = [tokens2use, attention_mask_repeat] + tensor_shape = [tokens2use.shape[1], micro_batch_size, self.model.cfg.hidden_size] + return batch, (tensor_shape, context_length) + + def forward_step(self, batch, tensor_shape_and_context_length): + tensor_shape, context_length = tensor_shape_and_context_length + fwd_bwd_function = get_forward_backward_func() + + output_tensor = fwd_bwd_function( + forward_step_func=self.model.get_forward_output_only_func(), + data_iterator=iter([batch,]), + model=[self.forward_model], + num_microbatches=get_num_microbatches(), + forward_only=True, + seq_length=tensor_shape[0], + micro_batch_size=tensor_shape[1], + ) + + output_tensor[0]['logits'] = output_tensor[0]['logits'][:, :context_length, :] + return output_tensor + + def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, conv_template): from nemo.collections.multimodal.data.neva.neva_dataset import ( DEFAULT_IMAGE_TOKEN, @@ -821,6 +893,7 @@ def model_inference_strategy_dispatcher(model, **args): from nemo.collections.nlp.models.language_modeling.megatron_gpt_prompt_learning_model import ( MegatronGPTPromptLearningModel, ) + from nemo.collections.nlp.models.language_modeling.megatron_griffin_model import MegatronGriffinModel from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel from nemo.collections.nlp.modules.common.retro_inference_strategies import ( @@ -829,6 +902,8 @@ def model_inference_strategy_dispatcher(model, **args): RetroQAModelTextGenerationStrategy, ) + if isinstance(model, MegatronGriffinModel): + return GriffinModelTextGenerationStrategy(model) if isinstance(model, MegatronNevaModel): return NevaModelTextGenerationStrategy(model) if isinstance(model, MegatronGPTPromptLearningModel): diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index 46e82089f0ea..9fd75ad8a95a 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -1,4 +1,6 @@ +accelerated-scan boto3 +causal-conv1d>=1.2.0 einops faiss-cpu fasttext diff --git a/scripts/checkpoint_converters/convert_griffin_hf_to_nemo.py b/scripts/checkpoint_converters/convert_griffin_hf_to_nemo.py new file mode 100644 index 000000000000..44435cc21135 --- /dev/null +++ b/scripts/checkpoint_converters/convert_griffin_hf_to_nemo.py @@ -0,0 +1,174 @@ +import os +from argparse import ArgumentParser + +import torch +from omegaconf.omegaconf import OmegaConf +from transformers import AutoModelForCausalLM + +from nemo.collections.nlp.models.language_modeling.megatron_griffin_model import MegatronGriffinModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.utils import logging + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--hparams_file", + type=str, + default=f"{os.path.dirname(__file__)}/../../examples/nlp/language_modeling/conf/megatron_griffin_config.yaml", + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") + parser.add_argument("--input_name_or_path", type=str, default="google/recurrentgemma-2b") + parser.add_argument( + "--precision", type=str, default="32", choices=["bf16", "32"], help="Precision for checkpoint weights saved" + ) + args = parser.parse_args() + return args + + +def convert(args): + + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.trainer["precision"] = args.precision + + logging.info(f"Loading checkpoint from HF: `{args.input_name_or_path}`") + hf_model = AutoModelForCausalLM.from_pretrained(args.input_name_or_path, device_map="auto") + + trainer = MegatronLMPPTrainerBuilder(nemo_config).create_trainer() + + nemo_model_from_hf = MegatronGriffinModel(nemo_config.model, trainer) + + new_state_dict = {} + + new_state_dict['model.embedding.word_embeddings.weight'] = hf_model.state_dict()['model.embed_tokens.weight'] + new_state_dict['model.decoder.final_layernorm.weight'] = hf_model.state_dict()['model.final_norm.weight'] + + for l in range(nemo_config.model.num_layers): + print(f"Converting Layer {l}") + print("********************") + + new_state_dict[f'model.decoder.layers.{l}.mlp.linear_fc1.weight'] = torch.cat( + [ + hf_model.state_dict()[f'model.layers.{l}.mlp_block.gate_proj.weight'], + hf_model.state_dict()[f'model.layers.{l}.mlp_block.up_proj.weight'], + ] + ) + new_state_dict[f'model.decoder.layers.{l}.mlp.linear_fc1.bias'] = torch.cat( + [ + hf_model.state_dict()[f'model.layers.{l}.mlp_block.gate_proj.bias'], + hf_model.state_dict()[f'model.layers.{l}.mlp_block.up_proj.bias'], + ] + ).flatten() + new_state_dict[f'model.decoder.layers.{l}.mlp.linear_fc2.weight'] = hf_model.state_dict()[ + f'model.layers.{l}.mlp_block.down_proj.weight' + ] + new_state_dict[f'model.decoder.layers.{l}.mlp.linear_fc2.bias'] = hf_model.state_dict()[ + f'model.layers.{l}.mlp_block.down_proj.bias' + ] + new_state_dict[f'model.decoder.layers.{l}.mlp.linear_fc1._extra_state'] = nemo_model_from_hf.state_dict()[ + f'model.decoder.layers.{l}.mlp.linear_fc1._extra_state' + ] + new_state_dict[f'model.decoder.layers.{l}.mlp.linear_fc2._extra_state'] = nemo_model_from_hf.state_dict()[ + f'model.decoder.layers.{l}.mlp.linear_fc2._extra_state' + ] + + new_state_dict[f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight'] = hf_model.state_dict()[ + f'model.layers.{l}.channel_pre_norm.weight' + ] + + if l % 3 == 2: + new_state_dict[f'model.decoder.layers.{l}.self_attention.linear_proj.weight'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.o_proj.weight' + ] + new_state_dict[f'model.decoder.layers.{l}.self_attention.linear_proj.bias'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.o_proj.bias' + ] + new_state_dict[ + f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight' + ] = hf_model.state_dict()[f'model.layers.{l}.temporal_pre_norm.weight'] + new_state_dict[f'model.decoder.layers.{l}.self_attention.linear_qkv.weight'] = torch.cat( + [ + hf_model.state_dict()[f'model.layers.{l}.temporal_block.q_proj.weight'], + hf_model.state_dict()[f'model.layers.{l}.temporal_block.k_proj.weight'], + hf_model.state_dict()[f'model.layers.{l}.temporal_block.v_proj.weight'], + ] + ) + new_state_dict[f'model.decoder.layers.{l}.self_attention.linear_qkv.bias'] = torch.zeros( + new_state_dict[f'model.decoder.layers.{l}.self_attention.linear_qkv.weight'].shape[0] + ) + new_state_dict[ + f'model.decoder.layers.{l}.self_attention.linear_proj._extra_state' + ] = nemo_model_from_hf.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_proj._extra_state'] + new_state_dict[ + f'model.decoder.layers.{l}.self_attention.linear_qkv._extra_state' + ] = nemo_model_from_hf.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv._extra_state'] + + else: + + new_state_dict[ + f'model.decoder.layers.{l}.recurrent_layer.linear_in.layer_norm_weight' + ] = hf_model.state_dict()[f'model.layers.{l}.temporal_pre_norm.weight'] + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.linear_in.weight'] = torch.cat( + [ + hf_model.state_dict()[f'model.layers.{l}.temporal_block.linear_x.weight'], + hf_model.state_dict()[f'model.layers.{l}.temporal_block.linear_y.weight'], + ] + ) + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.linear_in.bias'] = torch.cat( + [ + hf_model.state_dict()[f'model.layers.{l}.temporal_block.linear_x.bias'], + hf_model.state_dict()[f'model.layers.{l}.temporal_block.linear_y.bias'], + ] + ) + + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.linear_out.weight'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.linear_out.weight' + ] + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.linear_out.bias'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.linear_out.bias' + ] + + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.conv_1d.conv_1d.weight'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.conv_1d.weight' + ] + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.conv_1d.conv_1d.bias'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.conv_1d.bias' + ] + + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.rg_lru.a_param'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.rg_lru.recurrent_param' + ] + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.rg_lru.input_gate.w'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.rg_lru.input_gate_weight' + ] + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.rg_lru.input_gate.b'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.rg_lru.input_gate_bias' + ] + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.rg_lru.a_gate.w'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.rg_lru.recurrent_gate_weight' + ] + new_state_dict[f'model.decoder.layers.{l}.recurrent_layer.rg_lru.a_gate.b'] = hf_model.state_dict()[ + f'model.layers.{l}.temporal_block.rg_lru.recurrent_gate_bias' + ] + + new_state_dict[ + f'model.decoder.layers.{l}.recurrent_layer.linear_in._extra_state' + ] = nemo_model_from_hf.state_dict()[f'model.decoder.layers.{l}.recurrent_layer.linear_in._extra_state'] + new_state_dict[ + f'model.decoder.layers.{l}.recurrent_layer.linear_out._extra_state' + ] = nemo_model_from_hf.state_dict()[f'model.decoder.layers.{l}.recurrent_layer.linear_out._extra_state'] + + nemo_model_from_hf.load_state_dict(new_state_dict, strict=True) + dtype = torch_dtype_from_precision(args.precision) + nemo_model_from_hf = nemo_model_from_hf.to(dtype=dtype) + + nemo_model_from_hf.save_to(args.output_path) + logging.info(f'Griffin NeMo model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/checkpoint_converters/convert_griffin_nemo_to_hf.py b/scripts/checkpoint_converters/convert_griffin_nemo_to_hf.py new file mode 100644 index 000000000000..265af9e55cbd --- /dev/null +++ b/scripts/checkpoint_converters/convert_griffin_nemo_to_hf.py @@ -0,0 +1,147 @@ +import os +from argparse import ArgumentParser + +from omegaconf.omegaconf import OmegaConf +from transformers import AutoConfig, RecurrentGemmaModel + +from nemo.collections.nlp.models.language_modeling.megatron_griffin_model import MegatronGriffinModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.utils import logging + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--hparams_file", + type=str, + default=f"{os.path.dirname(__file__)}/../../examples/nlp/language_modeling/conf/megatron_griffin_config.yaml", + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") + parser.add_argument("--input_path", type=str, default=None, required=True) + parser.add_argument( + "--precision", type=str, default="32", choices=["bf16", "32"], help="Precision for checkpoint weights saved" + ) + args = parser.parse_args() + return args + + +def convert(args): + + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.trainer["precision"] = args.precision + + logging.info(f"Loading checkpoint from NeMo: `{args.input_path}`") + + trainer = MegatronLMPPTrainerBuilder(nemo_config).create_trainer() + + nemo_model = MegatronGriffinModel.restore_from(args.input_path, trainer=trainer) + hf_config = AutoConfig.from_pretrained("google/recurrentgemma-2b") + + # NeMo doesn't support LM Head for Griffin yet, so RecurrentGemmaModel is used instead of AutoModelForCausalLM + hf_model = RecurrentGemmaModel._from_config(hf_config) + + new_state_dict = {} + + new_state_dict['embed_tokens.weight'] = nemo_model.state_dict()['model.embedding.word_embeddings.weight'] + new_state_dict['final_norm.weight'] = nemo_model.state_dict()['model.decoder.final_layernorm.weight'] + + for l in range(nemo_config.model.num_layers): + print(f"Converting Layer {l}") + print("********************") + + ( + new_state_dict[f'layers.{l}.mlp_block.gate_proj.weight'], + new_state_dict[f'layers.{l}.mlp_block.up_proj.weight'], + ) = nemo_model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc1.weight'].chunk(2) + ( + new_state_dict[f'layers.{l}.mlp_block.gate_proj.bias'], + new_state_dict[f'layers.{l}.mlp_block.up_proj.bias'], + ) = nemo_model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc1.bias'].chunk(2) + new_state_dict[f'layers.{l}.mlp_block.down_proj.weight'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.mlp.linear_fc2.weight' + ] + new_state_dict[f'layers.{l}.mlp_block.down_proj.bias'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.mlp.linear_fc2.bias' + ] + + new_state_dict[f'layers.{l}.channel_pre_norm.weight'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight' + ] + + if l % 3 == 2: + + new_state_dict[f'layers.{l}.temporal_block.o_proj.weight'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.self_attention.linear_proj.weight' + ] + new_state_dict[f'layers.{l}.temporal_block.o_proj.bias'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.self_attention.linear_proj.bias' + ] + new_state_dict[f'layers.{l}.temporal_pre_norm.weight'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight' + ] + ( + new_state_dict[f'layers.{l}.temporal_block.q_proj.weight'], + new_state_dict[f'layers.{l}.temporal_block.k_proj.weight'], + new_state_dict[f'layers.{l}.temporal_block.v_proj.weight'], + ) = nemo_model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.weight'].split( + [2560, 256, 256] + ) + + else: + + new_state_dict[f'layers.{l}.temporal_pre_norm.weight'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.recurrent_layer.linear_in.layer_norm_weight' + ] + ( + new_state_dict[f'layers.{l}.temporal_block.linear_x.weight'], + new_state_dict[f'layers.{l}.temporal_block.linear_y.weight'], + ) = nemo_model.state_dict()[f'model.decoder.layers.{l}.recurrent_layer.linear_in.weight'].chunk(2) + ( + new_state_dict[f'layers.{l}.temporal_block.linear_x.bias'], + new_state_dict[f'layers.{l}.temporal_block.linear_y.bias'], + ) = nemo_model.state_dict()[f'model.decoder.layers.{l}.recurrent_layer.linear_in.bias'].chunk(2) + + new_state_dict[f'layers.{l}.temporal_block.linear_out.weight'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.recurrent_layer.linear_out.weight' + ] + new_state_dict[f'layers.{l}.temporal_block.linear_out.bias'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.recurrent_layer.linear_out.bias' + ] + + new_state_dict[f'layers.{l}.temporal_block.conv_1d.weight'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.recurrent_layer.conv_1d.conv_1d.weight' + ] + new_state_dict[f'layers.{l}.temporal_block.conv_1d.bias'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.recurrent_layer.conv_1d.conv_1d.bias' + ] + + new_state_dict[f'layers.{l}.temporal_block.rg_lru.recurrent_param'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.recurrent_layer.rg_lru.a_param' + ] + new_state_dict[f'layers.{l}.temporal_block.rg_lru.input_gate_weight'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.recurrent_layer.rg_lru.input_gate.w' + ] + new_state_dict[f'layers.{l}.temporal_block.rg_lru.input_gate_bias'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.recurrent_layer.rg_lru.input_gate.b' + ] + new_state_dict[f'layers.{l}.temporal_block.rg_lru.recurrent_gate_weight'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.recurrent_layer.rg_lru.a_gate.w' + ] + new_state_dict[f'layers.{l}.temporal_block.rg_lru.recurrent_gate_bias'] = nemo_model.state_dict()[ + f'model.decoder.layers.{l}.recurrent_layer.rg_lru.a_gate.b' + ] + + hf_model.load_state_dict(new_state_dict, strict=True) + dtype = torch_dtype_from_precision(args.precision) + hf_model = hf_model.to(dtype=dtype) + + hf_model.save_pretrained(args.output_path) + logging.info(f'Full HF model model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) From 93e326592c67f17af5c97fb4bd1f69371b861055 Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Fri, 3 May 2024 12:10:59 -0400 Subject: [PATCH 17/73] Llama3 Conversion Script Update (#9089) * Add conversion script and CI test * fix llama2 vocab_file * typo --- .github/workflows/cicd-main.yml | 25 +++++++++++++++ .../convert_llama_hf_to_nemo.py | 32 ++++++++++++++++--- .../convert_llama_nemo_to_hf.py | 2 +- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index df631443e7f7..8389efff07ad 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -271,6 +271,31 @@ jobs: - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" + L2_Community_LLM_Checkpoints_tests_Llama3: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v2 + - run: | + CUDA_VISIBLE_DEVICES=0 python scripts/checkpoint_converters/convert_llama_hf_to_nemo.py \ + --input_name_or_path=/home/TestData/nlp/megatron_llama/llama3-ci-hf \ + --output_path=/home/TestData/nlp/megatron_llama/llama3-ci-hf/llama3_ci.nemo \ + --precision=16 + rm -f /home/TestData/nlp/megatron_llama/llama3-ci-hf/llama3_ci.nemo + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + L2_Community_LLM_Checkpoints_tests_StarCoder: needs: [cicd-test-container-setup] runs-on: self-hosted-azure diff --git a/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py b/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py index c8ccf50aa05f..e1dc00c77439 100644 --- a/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py @@ -27,7 +27,7 @@ import torch from omegaconf import OmegaConf from pytorch_lightning.trainer.trainer import Trainer -from transformers import LlamaForCausalLM, LlamaTokenizer +from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.parts.nlp_overrides import ( @@ -78,7 +78,19 @@ def load_config(args, llama_config): nemo_config.num_query_groups = llama_config['num_key_value_heads'] nemo_config.use_cpu_initialization = True nemo_config.activation = 'fast-swiglu' - nemo_config.tokenizer.model = llama_config['tokenizer_model'] + + # Tokenizer config + if 'tokenizer_model' in llama_config: + nemo_config.tokenizer.model = llama_config['tokenizer_model'] + else: + # Llama3 uses converted TikToken Tokenizer + tokenizer_dict = { + 'library': 'huggingface', + 'type': args.input_name_or_path, + 'use_fast': True, + } + nemo_config.tokenizer = tokenizer_dict + if llama_config['rope_scaling'] is not None: if llama_config['rope_scaling']['type'] == 'linear': nemo_config['seq_len_interpolation_factor'] = llama_config['rope_scaling']['factor'] @@ -98,9 +110,12 @@ def load_config(args, llama_config): def convert(args): logging.info(f"loading checkpoint {args.input_name_or_path}") model = LlamaForCausalLM.from_pretrained(args.input_name_or_path) - tokenizer = LlamaTokenizer.from_pretrained(args.input_name_or_path) hf_config = vars(model.config) - hf_config['tokenizer_model'] = str(tokenizer.vocab_file) + if os.path.exists(f'{args.input_name_or_path}/tokenizer.model'): + tokenizer = LlamaTokenizer.from_pretrained(args.input_name_or_path) + hf_config['tokenizer_model'] = str(tokenizer.vocab_file) + else: + tokenizer = AutoTokenizer.from_pretrained(args.input_name_or_path) print(f"hf_config: {hf_config}") print("named parameters:") for name, param in model.named_parameters(): @@ -274,6 +289,15 @@ def convert(args): model._save_restore_connector = NLPSaveRestoreConnector() + # We make sure that the tokenizer can be instantiated later regardless of args.input_name_or_path + if 'tokenizer_model' not in hf_config: + if hf_config['num_hidden_layers'] == 32: + model.cfg.tokenizer.update(type='meta-llama/Meta-Llama-3-8B') + elif hf_config['num_hidden_layers'] == 80: + model.cfg.tokenizer.update(type='meta-llama/Meta-Llama-3-70B') + else: + logging.warning("Unexpected model config for Llama3. Tokenizer config has not been modified.") + # cast to target precision and disable cpu init dtype = torch_dtype_from_precision(precision) model = model.to(dtype=dtype) diff --git a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py index 159676f8b58e..8da15148dfd8 100644 --- a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py @@ -263,5 +263,5 @@ def replace_hf_weights_and_tokenizer( args.hf_output_tokenizer, ) else: - logging.info("`hf-in-path` and/or `hf-out-path` not provided, not generating full HF model.") + logging.info("`hf_input_path` and/or `hf_output_path` not provided, not generating full HF model.") logging.info(f".bin file is saved to {args.output_path}") From 805e5ec595dd217a3c3b39577e0e998b2ce38570 Mon Sep 17 00:00:00 2001 From: Jason Date: Fri, 3 May 2024 12:23:31 -0400 Subject: [PATCH 18/73] Update radtts.py (#9097) * Update radtts.py Signed-off-by: Jason * Update Jenkinsfile Signed-off-by: Jason * Update cicd-main.yml Signed-off-by: Jason * Update cicd-main.yml Signed-off-by: Jason * Update Jenkinsfile Signed-off-by: Jason * Update cicd-main.yml Signed-off-by: Jason --------- Signed-off-by: Jason --- .github/workflows/cicd-main.yml | 3 ++- examples/tts/radtts.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 8389efff07ad..ad6a1faf78ae 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -78,7 +78,7 @@ jobs: run: | # Pull base PyTorch container docker pull nvcr.io/nvidia/pytorch:24.02-py3 - docker run --device=/dev/nvidia0 --gpus all --shm-size=8g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume ${{ github.workspace }}/${{ github.run_id }}:/workspace --volume /mnt/datadrive/TestData:/home/TestData nvcr.io/nvidia/pytorch:24.02-py3 /bin/bash -c ' + docker run --device=/dev/nvidia0 --gpus all --shm-size=8g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --env PYTHONUNBUFFERED=1 --volume ${{ github.workspace }}/${{ github.run_id }}:/workspace --volume /mnt/datadrive/TestData:/home/TestData nvcr.io/nvidia/pytorch:24.02-py3 /bin/bash -c ' set -x # PyTorch version @@ -6224,6 +6224,7 @@ jobs: L2_TTS_Fast_dev_runs_1_RADTTS: needs: [cicd-test-container-setup] runs-on: self-hosted-azure + timeout-minutes: 15 container: image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} options: diff --git a/examples/tts/radtts.py b/examples/tts/radtts.py index 7dbdaedced03..09bf69a2d6e5 100644 --- a/examples/tts/radtts.py +++ b/examples/tts/radtts.py @@ -68,7 +68,7 @@ def main(cfg): lr_logger = pl.callbacks.LearningRateMonitor() epoch_time_logger = LogEpochTimeCallback() trainer.callbacks.extend([lr_logger, epoch_time_logger]) - trainer.fit(model.cuda()) + trainer.fit(model) if __name__ == '__main__': From f28773f14f6bdc8d0f7f7bee1da17aea44c2f803 Mon Sep 17 00:00:00 2001 From: mikolajblaz Date: Fri, 3 May 2024 19:29:56 +0200 Subject: [PATCH 19/73] Implement DistributedCheckpointIO (#9016) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduce DistributedCheckpointIO Signed-off-by: Mikołaj Błaż * Fix DistCkptIO usage Signed-off-by: Mikołaj Błaż * Use NeMo logger Signed-off-by: Mikołaj Błaż * [DCIO] Fix save_to dist ckpt path Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use dist ckpt flag in all methods Signed-off-by: Mikołaj Błaż * Improve error msg Signed-off-by: Mikołaj Błaż * Add dist ckpt unit tests Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix load_checkpoint Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Mikołaj Błaż * Fix auto-issues Signed-off-by: Mikołaj Błaż * Fix ckpt_dir var Signed-off-by: Mikołaj Błaż * Restore skipping behavior The fix from prevent-duplicated-checkpoints is required to skip the checkpoints Signed-off-by: Mikołaj Błaż * Fix steps on single-GPU machine Signed-off-by: Mikołaj Błaż * Add docs Signed-off-by: Mikołaj Błaż * Apply black Signed-off-by: Mikołaj Błaż * Fix num steps in tests Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Mikołaj Błaż * Use dist-ckpt for Bert Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix load checkpoint return val Signed-off-by: Mikołaj Błaż * Use dist-ckpt based on sharded_state_dict Signed-off-by: Mikołaj Błaż * Use correct checkpoint_io Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Mikołaj Błaż Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../conf/megatron_gpt_config.yaml | 4 +- .../nlp/parts/megatron_trainer_builder.py | 9 +- nemo/collections/nlp/parts/nlp_overrides.py | 84 +++++++--------- nemo/utils/callbacks/dist_ckpt_io.py | 85 ++++++++++++++++ tests/core/test_dist_ckpt.py | 98 +++++++++++++++++++ 5 files changed, 227 insertions(+), 53 deletions(-) create mode 100644 nemo/utils/callbacks/dist_ckpt_io.py create mode 100644 tests/core/test_dist_ckpt.py diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index ea37237f2eac..57c82726ae11 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -150,8 +150,8 @@ model: fsdp_grad_reduce_dtype: 32 # Gradient reduction data type. fsdp_sharded_checkpoint: False # Store and load FSDP shared checkpoint. - # PyTorch distributed checkpoint - torch_distributed_checkpoint: False # Set to True to use PyTorch distributed checkpoint format. + # Distributed checkpoint format + dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. ## Activation Checkpointing # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index 6b9763a53414..ad184157abc3 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -31,6 +31,7 @@ PipelineMixedPrecisionPlugin, ) from nemo.utils import logging +from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO class MegatronTrainerBuilder: @@ -80,7 +81,6 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]: find_unused_parameters=False, nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None), sharp=self.cfg.model.get('sharp', False), - torch_dist_ckpt=self.cfg.model.get('torch_distributed_checkpoint', False), ) def _grad_scaler(self) -> GradScaler: @@ -127,6 +127,13 @@ def _plugins(self) -> list: if self.cfg.get('cluster_type', None) == 'BCP': plugins.append(TorchElasticEnvironment()) + # Use dist-ckt for non-FSDP MCore models + use_dist_ckpt = not self.cfg.model.get('fsdp', False) and ( + self.cfg.model.get('mcore_gpt', False) or self.cfg.model.get('mcore_bert', False) + ) + if use_dist_ckpt: + plugins.append(DistributedCheckpointIO(self.cfg.model.get('dist_ckpt_format', 'zarr'))) + return plugins def create_trainer(self, callbacks=None) -> Trainer: diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 0a030759fe9b..b477c64a7510 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -25,6 +25,7 @@ import pytorch_lightning as pl import torch +from lightning_fabric.plugins import TorchCheckpointIO from lightning_fabric.utilities.cloud_io import get_filesystem from lightning_fabric.utilities.optimizer import _optimizer_to_device from omegaconf import OmegaConf @@ -54,6 +55,8 @@ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.nn.parallel import DistributedDataParallel +from nemo.utils.get_rank import is_global_rank_zero + try: from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state except ImportError: @@ -68,7 +71,6 @@ from nemo.core.optim import MainParamsOptimizerWrapper from nemo.core.optim.optimizers import init_optimizer_states from nemo.utils import AppState, logging -from nemo.utils.get_rank import is_global_rank_zero from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank try: @@ -104,6 +106,7 @@ from megatron.core.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate from megatron.core.transformer.module import Float16Module as MCoreFloat16Module from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer + from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO HAVE_MEGATRON_CORE = True @@ -178,7 +181,6 @@ def __init__( no_ddp_communication_hook: bool = False, nccl_communicator_config_path: Optional[str] = None, sharp: bool = False, - torch_dist_ckpt: bool = False, **kwargs: Union[Any, Dict[str, Any]], ) -> None: if not HAVE_APEX: @@ -195,7 +197,6 @@ def __init__( self.no_ddp_communication_hook = no_ddp_communication_hook self.nccl_communicator_config_path = nccl_communicator_config_path self.sharp = sharp - self.torch_dist_ckpt = torch_dist_ckpt def setup(self, trainer: "pl.Trainer") -> None: """ @@ -350,10 +351,7 @@ def save_checkpoint( called on every rank and internally does the rank checking. """ # check if using distributed checkpointing - if ( - hasattr(self.lightning_module, 'sharded_state_dict') - and self.lightning_module.sharded_state_dict() is not None - ): + if self.use_distributed_checkpointing: assert ( len(checkpoint['optimizer_states']) == 1 ), "Currently only support checkpointing 1 distributed optimizer per time!" @@ -371,16 +369,10 @@ def save_checkpoint( logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving') return - if is_global_rank_zero(): - fs.makedirs(checkpoint_dir, exist_ok=True) - # remove device state_dict checkpoint['state_dict'] = OrderedDict([]) - sharded_strategy = ('torch_dist', 1) if self.torch_dist_ckpt else ('zarr', 1) - dist_checkpointing.save( - sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir, sharded_strategy=sharded_strategy - ) + self.checkpoint_io.save_checkpoint(checkpoint, ckpt_to_dir(filepath), storage_options=storage_options) else: # PTL override to accomodate model parallel checkpoints filepath = inject_model_parallel_rank(filepath) @@ -390,10 +382,7 @@ def save_checkpoint( # PTL 2.2 supports non strict loading of the ckpt with the strict arg (https://github.com/Lightning-AI/pytorch-lightning/pull/19404) def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: # if using distributed checkpointing, the state dict logic is at the model level - if ( - hasattr(self.lightning_module, 'sharded_state_dict') - and self.lightning_module.sharded_state_dict() is not None - ): + if self.use_distributed_checkpointing: return # legacy state dict logic, does not use megatron core @@ -442,11 +431,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: fs = get_filesystem(checkpoint_path) # Check if using distributed checkpointing - if ( - hasattr(self.lightning_module, 'sharded_state_dict') - and self.lightning_module.sharded_state_dict() is not None - ): - + if self.use_distributed_checkpointing: # Distributed checkpoints must be directories. if not fs.isdir(checkpoint_path): raise ValueError(f'Distributed checkpoints should be a directory. Found: {checkpoint_path}.') @@ -458,16 +443,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: # after dist_checkpointing.load, sharded tensors will be replaced with tensors checkpoint['state_dict'] = sharded_state_dict checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()] - - if self.torch_dist_ckpt: - sharded_strategy = ('torch_dist', 1) - else: - sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True) - checkpoint = dist_checkpointing.load( - sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path, sharded_strategy=sharded_strategy - ) - - return checkpoint + return self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=checkpoint) # Legacy model parallel checkpointing logic, does not use megatron core else: @@ -480,12 +456,8 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: def remove_checkpoint(self, filepath: Union[str, Path]) -> None: # check if filepath is a distributed checkpoint - if ( - hasattr(self.lightning_module, 'sharded_state_dict') - and self.lightning_module.sharded_state_dict() is not None - ): - if self.is_global_zero: - shutil.rmtree(ckpt_to_dir(filepath), ignore_errors=True) + if self.use_distributed_checkpointing and self.is_global_zero: + self.checkpoint_io.remove_checkpoint(ckpt_to_dir(filepath)) # legacy checkpoint logic, does not use megatron core else: @@ -496,6 +468,25 @@ def remove_checkpoint(self, filepath: Union[str, Path]) -> None: logging.info(f'Removing checkpoint: {filepath}') self.checkpoint_io.remove_checkpoint(filepath) + @property + def use_distributed_checkpointing(self): + has_dist_ckpt_io = HAVE_MEGATRON_CORE and isinstance(self.checkpoint_io, DistributedCheckpointIO) + has_sharded_state_dict = ( + hasattr(self.lightning_module, 'sharded_state_dict') + and self.lightning_module.sharded_state_dict() is not None + ) + if has_sharded_state_dict and not has_dist_ckpt_io: + logging.warning( + 'Distributed checkpoints requires DistributedCheckpointIO plugin to be used. Setting up a default now.' + ) + self.checkpoint_io = DistributedCheckpointIO(self.lightning_module.cfg.get('dist_ckpt_format', 'zarr')) + if not has_sharded_state_dict and has_dist_ckpt_io: + logging.warning( + 'DistributedCheckpointIO configured but should not be used. Reverting back to TorchCheckpointIO' + ) + self.checkpoint_io = TorchCheckpointIO() + return has_sharded_state_dict + @property def distributed_sampler_kwargs(self): app_state = AppState() @@ -887,14 +878,8 @@ def dummy(): if model.trainer.strategy.launcher is not None: model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) model.trainer.strategy.setup_environment() - sharded_strategy = ( - ('torch_dist', 1) if model.cfg.get("torch_distributed_checkpoint", False) else ('zarr', 1) - ) - dist_checkpointing.save( - sharded_state_dict=sharded_state_dict, - checkpoint_dir=dist_ckpt_dir, - sharded_strategy=sharded_strategy, - ) + checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr')) + checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir) else: @@ -1177,9 +1162,8 @@ def dummy(): tmp_model_weights_ckpt = os.path.join(tmpdir, self.model_weights_ckpt) tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0] assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.' - checkpoint = dist_checkpointing.load( - sharded_state_dict=checkpoint, checkpoint_dir=tmp_model_weights_dir - ) + checkpoint_io = DistributedCheckpointIO(conf.get('dist_ckpt_format', 'zarr')) + checkpoint = checkpoint_io.load_checkpoint(tmp_model_weights_dir, sharded_state_dict=checkpoint) instance.on_load_checkpoint(checkpoint) if hasattr(instance, 'setup_transformer_engine_tp_groups'): instance.setup_transformer_engine_tp_groups() diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py new file mode 100644 index 000000000000..7dff9b458a0d --- /dev/null +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -0,0 +1,85 @@ +import shutil +from typing import Any, Dict, Optional + +from lightning_fabric.plugins import CheckpointIO +from lightning_fabric.utilities.cloud_io import get_filesystem +from lightning_fabric.utilities.types import _PATH +from megatron.core import dist_checkpointing +from megatron.core.dist_checkpointing.strategies import tensorstore + +from nemo.utils import logging + + +class DistributedCheckpointIO(CheckpointIO): + """ CheckpointIO for a distributed checkpoint format. + + Args: + save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving. + """ + + def __init__(self, save_ckpt_format: str): + super().__init__() + self.save_ckpt_format = save_ckpt_format + + self.save_sharded_strategy = self.determine_dist_ckpt_save_strategy() + + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + """ Saves a distributed checkpoint. Creates the checkpoint root directory if doesn't exist. + + Args: + checkpoint (Dict[str, Any]): sharded state dict to save + path (_PATH): checkpoint directory + storage_options (Any, optional): Optional parameters when saving the checkpoint + """ + fs = get_filesystem(path) + fs.makedirs(path, exist_ok=True) + + dist_checkpointing.save( + sharded_state_dict=checkpoint, checkpoint_dir=path, sharded_strategy=self.save_sharded_strategy + ) + + def load_checkpoint( + self, path: _PATH, map_location: Optional[Any] = None, sharded_state_dict: Dict[str, Any] = None + ) -> Dict[str, Any]: + """ Loads a distributed checkpoint. + + Args: + path (_PATH): checkpoint directory + map_location (Any, optional): required to be None in this implementation + sharded_state_dict (Dict[str, Any], optional): state dict which + defines the loading procedure for the distributed checkpoint. + Defaults to None to comply with the CheckpointIO interface, + but it's a required argument. + + Returns: + Dist[str, Any]: loaded checkpoint. + """ + if sharded_state_dict is None: + raise ValueError('DistributedCheckpointIO requires passing sharded_state_dict argument to load_checkpoint') + if map_location is not None: + raise ValueError('DistributedCheckpointIO doesnt handle map_location argument') + + if self.save_ckpt_format == 'zarr': + sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True) + else: + sharded_strategy = None + + return dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, checkpoint_dir=path, sharded_strategy=sharded_strategy + ) + + def remove_checkpoint(self, path: _PATH) -> None: + """ Remove a distributed checkpoint. + + Due to potentially large number of files, the implementation remove the whole directory at once. + """ + shutil.rmtree(path, ignore_errors=True) + + def determine_dist_ckpt_save_strategy(self): + """ Determine the saving strategy based on storage config. + + For now only decides the checkpoint format. + """ + save_strategy = (self.save_ckpt_format, 1) + logging.info(f'Using {save_strategy} dist-ckpt save strategy.') + return save_strategy diff --git a/tests/core/test_dist_ckpt.py b/tests/core/test_dist_ckpt.py new file mode 100644 index 000000000000..b6dc5ca89d3e --- /dev/null +++ b/tests/core/test_dist_ckpt.py @@ -0,0 +1,98 @@ +import os +import types +from pathlib import Path + +import pytest +import pytorch_lightning as pl +import torch +from lightning_fabric.plugins import TorchCheckpointIO +from pytorch_lightning.demos.boring_classes import BoringModel + +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO + + +class ExampleModel(BoringModel): + def on_validation_epoch_end(self) -> None: + self.log("val_loss", torch.tensor(1.0)) + + +class ExampleMCoreModel(ExampleModel): + def sharded_state_dict(self): + return {'a': 3} + + +class MockDistributedCheckpointIO(DistributedCheckpointIO): + def __init__(self, save_ckpt_format): + super().__init__(save_ckpt_format) + self.save_checkpoint_called_args = None + + def save_checkpoint(self, *args, **kwargs) -> None: + self.save_checkpoint_called_args = args, kwargs + + +class MockTorchCheckpointIO(TorchCheckpointIO): + def __init__(self): + super().__init__() + self.save_checkpoint_called_args = None + + def save_checkpoint(self, *args, **kwargs) -> None: + self.save_checkpoint_called_args = args, kwargs + + +def _get_last_checkpoint_dir(root_dir: Path, model: pl.LightningModule, suffix: str = '') -> Path: + steps = len(model.train_dataloader().dataset) * model.trainer.max_epochs // torch.distributed.get_world_size() + return root_dir / 'checkpoints' / f'epoch=1-step={steps}{suffix}' + + +class TestDistCkptIO: + @pytest.mark.run_only_on('GPU') + def test_dist_ckpt_io_called_for_mcore_models(self, tmp_path): + strategy = NLPDDPStrategy() + # skip optimizer sharded state creation: + strategy.optimizer_sharded_state_dict = types.MethodType( + lambda self, unsharded_optim_state: unsharded_optim_state, strategy + ) + checkpoint_io = MockDistributedCheckpointIO('xxx') + + test_trainer = pl.Trainer( + enable_checkpointing=True, + logger=False, + max_epochs=2, + strategy=strategy, + plugins=[checkpoint_io], + default_root_dir=tmp_path, + ) + model = ExampleMCoreModel() + test_trainer.fit(model) + + assert isinstance(test_trainer.strategy.checkpoint_io, MockDistributedCheckpointIO) + assert checkpoint_io.save_checkpoint_called_args is not None + (state_dict, path), _ = checkpoint_io.save_checkpoint_called_args + # Ckpt path doesn't contain the .ckpt suffix + assert path.name == _get_last_checkpoint_dir(tmp_path, model).name, len(test_trainer.strategy.parallel_devices) + + @pytest.mark.run_only_on('GPU') + def test_dist_ckpt_path_not_executed_for_non_core_models(self, tmp_path): + strategy = NLPDDPStrategy() + checkpoint_io = MockTorchCheckpointIO() + + test_trainer = pl.Trainer( + enable_checkpointing=True, + logger=False, + max_epochs=2, + strategy=strategy, + plugins=[checkpoint_io], + default_root_dir=tmp_path, + ) + model = ExampleModel() + test_trainer.fit(model) + + assert isinstance(test_trainer.strategy.checkpoint_io, MockTorchCheckpointIO) + if test_trainer.is_global_zero: + assert checkpoint_io.save_checkpoint_called_args is not None + (state_dict, path), _ = checkpoint_io.save_checkpoint_called_args + # Ckpt path *does* contain the .ckpt suffix + assert os.path.basename(path) == _get_last_checkpoint_dir(tmp_path, model, suffix='.ckpt').name + else: + assert checkpoint_io.save_checkpoint_called_args is None From c5a5a79ee154b1a7cd0caf4ef63b35b1ea24768d Mon Sep 17 00:00:00 2001 From: paul-gibbons <87940629+paul-gibbons@users.noreply.github.com> Date: Fri, 3 May 2024 11:37:31 -0700 Subject: [PATCH 20/73] Video Neva Pretraining + Inference Implementation (#9095) * video_neva pretrain * support video neva inference Signed-off-by: Vivian Chen * yaml update, adding media_type * yaml update, adding media_type * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify neva inference config Signed-off-by: Vivian Chen * modify based on review Signed-off-by: Vivian Chen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove video test asset Signed-off-by: Vivian Chen * video_neva doc, describing config changes. Signed-off-by: paul-gibbons * Revert "video_neva doc, describing config changes." This reverts commit 1a02ccd3adf30e851b1f74b0780c4a785c92eb43. * vneva brief doc Signed-off-by: paul-gibbons * vneva doc update Signed-off-by: paul-gibbons * doc update Signed-off-by: paul-gibbons * Revert "doc update" This reverts commit 80af9a43a342fa3ab1c7a4f002694bb23fd2af91. * doc update Signed-off-by: paul-gibbons * Revert "doc update" This reverts commit 8c885c7633b8b04ebdf3ce8280f2c3bb54ed0f20. * doc update Signed-off-by: paul-gibbons * Revert "doc update" This reverts commit 94aba65911d518b083c9a238c8f02d06979ef1ec. * doc update Signed-off-by: paul-gibbons * add inference doc to docs, resolve review Signed-off-by: Vivian Chen * modify inference config for other mlm Signed-off-by: Vivian Chen --------- Signed-off-by: Vivian Chen Signed-off-by: paul-gibbons Co-authored-by: Vivian Chen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/multimodal/mllm/video_neva.rst | 134 +++++++++ .../neva/conf/llava_config.yaml | 1 + .../multimodal_llm/neva/conf/neva_config.yaml | 1 + .../neva/conf/neva_finetune.yaml | 1 + .../neva/conf/neva_inference.yaml | 5 +- .../multimodal_llm/neva/conf/neva_peft.yaml | 1 + .../neva/conf/video_neva_config.yaml | 222 ++++++++++++++ .../multimodal_llm/neva/eval/gradio_server.py | 6 +- .../multimodal_llm/neva/eval/vqa_science.py | 10 +- .../multimodal_llm/neva/neva_evaluation.py | 22 +- .../multimodal/data/neva/conversation.py | 1 + .../multimodal/data/neva/neva_dataset.py | 274 +++++++++++++++--- .../models/multimodal_llm/neva/neva_model.py | 3 +- nemo/collections/multimodal/parts/utils.py | 76 ++++- .../common/text_generation_strategy.py | 15 +- .../modules/common/text_generation_utils.py | 6 +- requirements/requirements_multimodal.txt | 1 + tutorials/multimodal/NeVA Tutorial.ipynb | 6 +- 18 files changed, 702 insertions(+), 83 deletions(-) create mode 100644 docs/source/multimodal/mllm/video_neva.rst create mode 100644 examples/multimodal/multimodal_llm/neva/conf/video_neva_config.yaml diff --git a/docs/source/multimodal/mllm/video_neva.rst b/docs/source/multimodal/mllm/video_neva.rst new file mode 100644 index 000000000000..b5831a45ab28 --- /dev/null +++ b/docs/source/multimodal/mllm/video_neva.rst @@ -0,0 +1,134 @@ +Video NeVA +========== + +Model Introduction +------------------ + +Video NeVa adds support for video modality in NeVa by representing video as multiple image frames. + +There is only a minor change done to :class:`~nemo.collections.multimodal.models.multimodal_llm.neva.neva_model.MegatronNevaModel` class in order to support pretraining on video input data. + +Representing video input as a series of images is done in :class:`~nemo.collections.multimodal.data.neva.TarOrFolderVideoLoader` class, using Decord which provides convenient video slicing methods. + + +Video Neva Configuration +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: yaml + + data: + media_type: video + splice_single_frame: null + num_frames: 8 + image_token_len: 256 + image_folder: null + video_folder: null + +- ``media_type``: If set to `video`, NeVa's dataloader goes through the additional preprocessing steps to represent the input video data as a series of image frames. +- ``splice_single_frame``: Can either be set as `first`, `middle` or `last`. This will result in only a single frame in that specific location of the video being selected. +- ``image_token_len``: The NeVa dataloader calculates `image_token_len` based on the height and width of the preprocessed image frame and the patch size of the CLIP model being used. + +.. code-block:: python + + image_token_len = (224 // 14) * (224 // 14) = 16 * 16 = 256 + +- ``num_frames``: This is used to select the number of image frames that will be used to represent the video. +- ``video_folder``: This specifies the directory where the video files are located. This follows the same format as NeVa's `image_folder`. + + + +Inference with Video NeVA +========================= + +We can run ``neva_evaluation.py`` located in ``NeMo/examples/multimodal/multimodal_llm/neva`` to generate inference results from the Video NeVA model. +Currently, video NeVA supports both image and video inference by changing the config attribute ``inference.media_type`` in ``NeMo/examples/multimodal/multimodal_llm/neva/conf/neva_inference.yaml`` to either ``image`` or ``video``, and adding the corresponding media path ``inference.media_base_path``. + +Inference with Pretrained Projectors with Base LM Model +------------------------------------------------------- + +An example of an inference script execution: + +For running video inference:: + + CUDA_DEVICE_MAX_CONNECTIONS=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python3 /path/to/neva_evaluation.py \ + --config-path=/path/to/conf/ \ + --config-name=neva_inference.yaml \ + tensor_model_parallel_size=4 \ + pipeline_model_parallel_size=1 \ + neva_model_file=/path/to/projector/checkpoint \ + base_model_file=/path/to/base/lm/checkpoint \ + trainer.devices=4 \ + trainer.precision=bf16 \ + prompt_file=/path/to/prompt/file \ + inference.media_base_path=/path/to/videos \ + inference.media_type=video \ + output_file=/path/for/output/file/ \ + inference.temperature=0.2 \ + inference.top_k=0 \ + inference.top_p=0.9 \ + inference.greedy=False \ + inference.add_BOS=False \ + inference.all_probs=False \ + inference.repetition_penalty=1.2 \ + inference.insert_media_token=right \ + inference.tokens_to_generate=256 \ + quantization.algorithm=awq \ + quantization.enable=False + +Example format of ``.jsonl`` prompt_file:: + + {"video": "video_test.mp4", "text": "Can you describe the scene?", "category": "conv", "question_id": 0} + +input video file:: video_test.mp4 + +Output:: + + System + A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. + + User + Can you describe the scene?