From 11eb478ac088f2faa26c2174e2af202bff93cf44 Mon Sep 17 00:00:00 2001 From: zhehuaichen <139396994+zhehuaichen@users.noreply.github.com> Date: Tue, 8 Aug 2023 11:17:49 -0400 Subject: [PATCH] Merge heh and zhehuai's initial version of frozen am+llm (#5) * Merge heh and zhehuai's initial version of frozen am+llm The previous differences are summarized here: https://docs.google.com/document/d/1zNI4hC6vJtUfcHbrUSPaMuYWRBQdN_36H0P2NiBiuPY/edit This PR includes 1. Finish merging the model, dataset, and config code 2. Previous tests are still enabled and passed (prepare_llm_input, training_step, validation_step) 3. the example training script with LS960 has been run to make sure the training pipeline works The major remaining works are listed here https://docs.google.com/document/d/1o0AM7v4gcTQkPZjE0Vl9TTX4vYnGTrbXEFGWh0UhGlk/edit#bookmark=id.pzvdadt5oxyw --------- Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> Signed-off-by: zhehuaichen --- .../modularized_speech_gpt_config.yaml | 294 ++++- .../common/parts/preprocessing/collections.py | 194 ++++ nemo/collections/multimodal/__init__.py | 2 +- nemo/collections/multimodal/data/__init__.py | 13 + .../multimodal/data/audio_text_qa_dataset.py | 1005 +++++++++++++++++ .../collections/multimodal/models/__init__.py | 2 +- .../multimodal/models/speechllm_models.py | 757 ++++++------- .../multimodal/modules/__init__.py | 2 +- .../modules/speechllm_perception.py | 71 +- nemo/collections/multimodal/parts/__init__.py | 0 .../multimodal/parts/utils/data_utils.py | 26 + .../megatron_gpt_peft_models.py | 2 +- .../multimodal/test_speechllm_models.py | 116 +- workspace/run_sft.sh | 15 + workspace/run_sft_audio_lm.py | 129 +++ workspace/scripts/convert_hf_llama_to_nemo.sh | 16 + workspace/scripts/convert_llama_to_hf.sh | 11 + workspace/tools/convert_hf_llama_to_nemo.py | 263 +++++ workspace/tools/convert_llama2_to_hf.py | 310 +++++ 19 files changed, 2625 insertions(+), 603 deletions(-) create mode 100644 nemo/collections/multimodal/data/__init__.py create mode 100644 nemo/collections/multimodal/data/audio_text_qa_dataset.py create mode 100644 nemo/collections/multimodal/parts/__init__.py create mode 100644 nemo/collections/multimodal/parts/utils/data_utils.py create mode 100755 workspace/run_sft.sh create mode 100644 workspace/run_sft_audio_lm.py create mode 100755 workspace/scripts/convert_hf_llama_to_nemo.sh create mode 100755 workspace/scripts/convert_llama_to_hf.sh create mode 100644 workspace/tools/convert_hf_llama_to_nemo.py create mode 100644 workspace/tools/convert_llama2_to_hf.py diff --git a/examples/multimodel/conf/speechllm/modularized_speech_gpt_config.yaml b/examples/multimodel/conf/speechllm/modularized_speech_gpt_config.yaml index 1325f21c8849..00ad7aea2d82 100644 --- a/examples/multimodel/conf/speechllm/modularized_speech_gpt_config.yaml +++ b/examples/multimodel/conf/speechllm/modularized_speech_gpt_config.yaml @@ -1,4 +1,18 @@ -name: modularized_speech_gpt +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +name: megatron_gpt_peft_tuning trainer: devices: 1 @@ -8,15 +22,11 @@ trainer: logger: False # logger provided by exp_manager enable_checkpointing: False replace_sampler_ddp: False - max_epochs: 3 # min 25 recommended - max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + max_epochs: 9999 + max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches log_every_n_steps: 10 # frequency with which training steps are logged - val_check_interval: 1.0 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + val_check_interval: 200 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch gradient_clip_val: 1.0 - resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. - benchmark: False - - exp_manager: explicit_log_dir: null @@ -30,12 +40,13 @@ exp_manager: resume_ignore_no_checkpoint: True create_checkpoint_callback: True checkpoint_callback_params: - monitor: val_loss - save_top_k: 2 + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 mode: min - save_nemo_on_train_end: True - filename: 'megatron_gpt_prompt_tune--{val_loss:.3f}-{step}' + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}' model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False save_best_model: True create_early_stopping_callback: True early_stopping_callback_params: @@ -45,28 +56,22 @@ exp_manager: patience: 10 verbose: True strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. - + model: - # LLM related stanza - precision: 16 seed: 1234 - nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved - # TODO(zhehuai): remove the following parameter which is not used - virtual_prompt_style: 'p-tuning' # one of 'prompt-tuning', 'p-tuning', or 'inference' tensor_model_parallel_size: 1 # intra-layer model parallelism pipeline_model_parallel_size: 1 # inter-layer model parallelism - global_batch_size: 8 - micro_batch_size: 4 - validation_global_batch_size: ${model.global_batch_size} - validation_micro_batch_size: ${model.micro_batch_size} - validation_drop_last: False - report_validation_metric: False - validation_metric: 'accuracy' + + pretrained_audio_model: stt_en_fastconformer_transducer_large - restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with - language_model_path: ??? # Path to the GPT language model .nemo file, always required - save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? 
# Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False ## Sequence Parallelism # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially @@ -80,31 +85,208 @@ model: # of each chunk at the specified granularity # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity activations_checkpoint_num_layers: null # not used with 'selective' - - # TODO(zhehuai): support task template to support complexer SFT format - task_templates: null - # TODO(zhehuai): update the following value according to the ckpt - existing_tasks: ['boolq', 'intent_and_slot'] # List of tasks the model has already been p-tuned/prompt-tuned for, needed when a restore path is given - new_tasks: ['asr'] # List of new tasknames to be prompt-tuned - - # TODO(zhehuai): audio related stanza - perception: null - fixed_prompt_prefix: 'Can you transcribe the speech into a written format?' + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + peft: + peft_scheme: "adapter" # can be either adapter,ia3, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre' or 'post', 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + + lora_tuning: + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + perception: + matcher: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. 
+ reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained AM: + # output_dim: null + # encoder: null + # preprocessor: null data: - train_ds: [data/rte_train.jsonl,] - validation_ds: [data/rte_val.jsonl,] - add_eos: True - shuffle: True - num_workers: 8 - pin_memory: True - train_cache_data_path: null # the path to the train cache data - validation_cache_data_path: null # the path to the validation cache data - test_cache_data_path: null # the path to the test cache data - load_cache: False # whether to load from the cache data - max_seq_length: 1024 # filter out training and validation examples longer than 1024 tokens. Set to None will default to model's encoder length. - min_seq_length: 1 # filter out training and validation examples less than 1 token long. + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_names: ??? # Path to a list of JSONL files corresponding to the source data. 
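+ # For concreteness, a minimal sketch of what one line of such a JSONL manifest could
+ # contain for the speech case: it pairs the standard NeMo audio keys (audio_filepath,
+ # duration) with the question/answer keys that the ALMAudioQA parser added in this
+ # patch understands. Paths and field values below are purely illustrative.
+ # import json
+ #
+ # example_entry = {
+ #     "audio_filepath": "/data/audio/utt_0001.wav",  # resolved relative to the manifest if needed
+ #     "duration": 4.2,                                # seconds; used for min/max duration filtering
+ #     "question": "Can you transcribe the speech into a written format?",
+ #     "answer": "hello world this is a test utterance",
+ # }
+ #
+ # with open("train_manifest.jsonl", "w") as f:
+ #     f.write(json.dumps(example_entry) + "\n")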
+ global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'input' + label_key: 'output' + add_eos: True + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "Q: {input}\nA: {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + + validation_ds: + file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: 'input' + label_key: 'output' + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + test_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. 
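+ # The prompt_template above is applied as plain string substitution before tokenization.
+ # A simplified sketch of the context/answer split performed by the dataset code in this
+ # patch (ignoring BOS/EOS, virtual tokens, and truncation); the sample values are made up.
+ # prompt_template = "Q: {input}\nA: {output}"
+ # question = "what does this audio mean"
+ # answer = "the quick brown fox"
+ #
+ # # Full training text (context followed by answer).
+ # text = prompt_template.replace("{input}", question).replace("{output}", answer)
+ #
+ # # Context-only prefix: template with the answer placeholder removed, used to find
+ # # where the answer (and therefore the loss region) starts after tokenization.
+ # context = prompt_template.replace("{input}", question).replace("{output}", "").strip(" ")
+ #
+ # print(repr(text))     # 'Q: what does this audio mean\nA: the quick brown fox'
+ # print(repr(context))  # 'Q: what does this audio mean\nA:'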
+ global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 4 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: 'input' + label_key: 'output' + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null optim: name: fused_adam @@ -119,14 +301,4 @@ model: min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 constant_steps: 0 # Constant steps should also be 0 when min_lr=0 monitor: val_loss - reduce_on_plateau: false - - # required for reporting validation metrics - inference: - greedy: False - top_k: 0 - top_p: 0.9 - temperature: 1.0 - tokens_to_generate: 30 - repetition_penalty: 1.2 - min_tokens_to_generate: 0 + reduce_on_plateau: false \ No newline at end of file diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index ed9e53ae6ffe..131c1aaa56e7 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -235,6 +235,200 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): ) +class AudioQuestAns(_Collection): + """List of audio-transcript text correspondence with preprocessing.""" + + OUTPUT_TYPE = collections.namedtuple( + typename='AudioQAEntity', field_names='id audio_file duration question answer offset speaker orig_sr lang', + ) + + def __init__( + self, + ids: List[int], + audio_files: List[str], + durations: List[float], + questions: List[str], + answers: List[str], + offsets: List[str], + speakers: List[Optional[int]], + orig_sampling_rates: List[Optional[int]], + langs: List[Optional[str]], + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + ): + """Instantiates audio-question-answer manifest with filters and preprocessing. + + Args: + ids: List of examples positions. + audio_files: List of audio files. + durations: List of float durations. + questions: List of raw text transcripts. + answers: List of raw text transcripts. + offsets: List of duration offsets or None. + speakers: List of optional speakers ids. + orig_sampling_rates: List of original sampling rates of audio files. + langs: List of language ids, one for eadh sample, or None. + min_duration: Minimum duration to keep entry with (default: None). + max_duration: Maximum duration to keep entry with (default: None). + max_number: Maximum number of samples to collect. 
+ do_sort_by_duration: True if sort samples list by duration. Not compatible with index_by_file_id. + index_by_file_id: If True, saves a mapping from filename base (ID) to index in data. + """ + + output_type = self.OUTPUT_TYPE + data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0 + if index_by_file_id: + self.mapping = {} + + for id_, audio_file, duration, offset, question, answer, speaker, orig_sr, lang in zip( + ids, audio_files, durations, offsets, questions, answers, speakers, orig_sampling_rates, langs + ): + # Duration filters. + if min_duration is not None and duration < min_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if max_duration is not None and duration > max_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if answer is None: + duration_filtered += duration + num_filtered += 1 + continue + + total_duration += duration + + data.append(output_type(id_, audio_file, duration, question, answer, offset, speaker, orig_sr, lang)) + if index_by_file_id: + file_id, _ = os.path.splitext(os.path.basename(audio_file)) + if file_id not in self.mapping: + self.mapping[file_id] = [] + self.mapping[file_id].append(len(data) - 1) + + # Max number of entities filter. + if len(data) == max_number: + break + + if do_sort_by_duration: + if index_by_file_id: + logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + + super().__init__(data) + + +class ALMAudioQA(AudioQuestAns): + """`AudioQuestAns` collector from audio-LM json files.""" + + def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): + """Parse lists of audio files, durations and transcripts texts. + + Args: + manifests_files: Either single string file or list of such - + manifests to yield items from. + *args: Args to pass to `AudioText` constructor. + **kwargs: Kwargs to pass to `AudioText` constructor. + """ + + ids, audio_files, durations, questions, answers, offsets, = ( + [], + [], + [], + [], + [], + [], + ) + speakers, orig_srs, langs = ( + [], + [], + [], + ) + for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item): + ids.append(item['id']) + audio_files.append(item['audio_file']) + durations.append(item['duration']) + questions.append(item['question']) + answers.append(item['answer']) + offsets.append(item['offset']) + speakers.append(item['speaker']) + orig_srs.append(item['orig_sr']) + langs.append(item['lang']) + super().__init__( + ids, audio_files, durations, questions, answers, offsets, speakers, orig_srs, langs, *args, **kwargs + ) + + def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: + item = json.loads(line) + + # Audio file + if 'audio_filename' in item: + item['audio_file'] = item.pop('audio_filename') + elif 'audio_filepath' in item: + item['audio_file'] = item.pop('audio_filepath') + elif 'audio_file' not in item: + raise ValueError( + f"Manifest file {manifest_file} has invalid json line structure: {line} without proper audio file key." + ) + + # If the audio path is a relative path and does not exist, + # try to attach the parent directory of manifest to the audio path. + # Revert to the original path if the new path still doesn't exist. 
+ # Assume that the audio path is like "wavs/xxxxxx.wav". + item['audio_file'] = manifest.get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file) + + # Duration. + if 'duration' not in item: + raise ValueError( + f"Manifest file {manifest_file} has invalid json line structure: {line} without proper duration key." + ) + + # Question. + if 'question' in item: + pass + elif 'question_filepath' in item: + with open(item.pop('text_filepath'), 'r') as f: + item['question'] = f.read().replace('\n', '') + elif 'normalized_text' in item: + item['question'] = item['normalized_text'] + else: + item['question'] = "what does this audio mean" + + # Answer. + if 'answer' in item: + pass + elif 'text' in item: + item['answer'] = item.pop('text') + elif 'text_filepath' in item: + with open(item.pop('text_filepath'), 'r') as f: + item['answer'] = f.read().replace('\n', '') + elif 'normalized_text' in item: + item['answer'] = item['normalized_text'] + else: + item['answer'] = "" + + item = dict( + audio_file=item['audio_file'], + duration=item['duration'], + question=item['question'], + answer=item['answer'], + offset=item.get('offset', None), + speaker=item.get('speaker', None), + orig_sr=item.get('orig_sample_rate', None), + lang=item.get('lang', None), + ) + return item + + class SpeechLabel(_Collection): """List of audio-label correspondence with preprocessing.""" diff --git a/nemo/collections/multimodal/__init__.py b/nemo/collections/multimodal/__init__.py index ebaf915b793d..25bd1810500a 100644 --- a/nemo/collections/multimodal/__init__.py +++ b/nemo/collections/multimodal/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.multimodal import models, modules \ No newline at end of file +from nemo.collections.multimodal import models, modules diff --git a/nemo/collections/multimodal/data/__init__.py b/nemo/collections/multimodal/data/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/data/audio_text_qa_dataset.py b/nemo/collections/multimodal/data/audio_text_qa_dataset.py new file mode 100644 index 000000000000..f14bab3abc92 --- /dev/null +++ b/nemo/collections/multimodal/data/audio_text_qa_dataset.py @@ -0,0 +1,1005 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import io +import json +import math +import multiprocessing +import os +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union + +import braceexpand +import numpy as np +import torch +import webdataset as wd +from torch.utils.data import ChainDataset +from tqdm import tqdm + +from nemo.collections.asr.data.audio_to_text import ( + cache_datastore_manifests, + expand_sharded_filepaths, + shard_manifests_if_needed, +) +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.common import tokenizers +from nemo.collections.common.parts.preprocessing import collections, parsers +from nemo.collections.multimodal.parts.utils.data_utils import ceil_to_nearest, maybe_cast_to_list +from nemo.core.classes import Dataset, IterableDataset +from nemo.core.neural_types import * +from nemo.utils import logging +from nemo.utils.data_utils import ( + DataStoreObject, + datastore_object_get, + datastore_path_to_webdataset_url, + is_datastore_cache_shared, + is_datastore_path, + is_tarred_path, +) +from nemo.utils.get_rank import is_global_rank_zero + +__all__ = [ + 'AudioQuestionAnswerDataset', + 'TarredAudioQuestionAnswerDataset', +] + + +def _speech_collate_fn(audio_signals, audio_lengths): + """collate batch of audio sig, audio len, tokens, tokens len + Args: + audio_signals: List[Tensor] + audio_lengths: List[Tensor] + """ + + max_audio_len = 0 + has_audio = audio_lengths[0] is not None + if has_audio: + max_audio_len = max(audio_lengths).item() + + audio_signals_padded = [] + for sig, sig_len in zip(audio_signals, audio_lengths): + if has_audio: + sig_len = sig_len.item() + if sig_len < max_audio_len: + pad = (0, max_audio_len - sig_len) + sig = torch.nn.functional.pad(sig, pad) + audio_signals_padded.append(sig) + + if has_audio: + audio_signals_padded = torch.stack(audio_signals_padded) + audio_lengths = torch.stack(audio_lengths) + else: + audio_signals_padded, audio_lengths = None, None + + return audio_signals_padded, audio_lengths + + +class TokenizerWrapper: + def __init__(self, tokenizer): + if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer): + self.is_aggregate = True + else: + self.is_aggregate = False + self._tokenizer = tokenizer + + def __call__(self, *args): + if isinstance(args[0], List) and self.is_aggregate: + t = [] + for span in args[0]: + t.extend(self._tokenizer.text_to_ids(span['str'], span['lang'])) + return t + + t = self._tokenizer.text_to_ids(*args) + return t + + +class AudioQuestionAnswerDataset(Dataset): + """ + Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds). + Each new line is a different sample. Example below: + {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + Args: + manifest_filepath: Path to manifest json as described above. Can be comma-separated paths. + parser: Str for a language specific preprocessor or a callable. + sample_rate (int): Sample rate to resample loaded audio to + int_values (bool): If true, load samples as 32-bit integers. Defauts to False. 
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor object used to augment loaded + audio + max_duration: If audio exceeds this length, do not include in dataset + min_duration: If audio is less than this length, do not include in dataset + max_utts: Limit number of utterances + trim: whether or not to trim silence. Defaults to False + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + --------- NLP SPECIFIC ARGS ------------- + max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. + min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. + add_bos (bool): Whether to add a beginning of sentence token to each data example + add_eos (bool): Whether to add an end of sentence token to each data example + add_sep (bool): Whether to add a separation token to each data example (goes between prompt and answer) + tokens_to_generate (int): (inference only) Number of tokens to generate during inference + seed: Random seed for data shuffling. + max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded. + seed: int = 1234, + input_key: Key to use for the context in your JSONL file + output_key: Key to use for the label in your JSONL file + separate_prompt_and_response_with_newline: Adds a newline between prompt and response. + answer_only_loss: If True, will compute the loss only on the answer part of the input. If False, will compute the loss on the entire input. + truncation_field: Field to use for truncation. (Options: "answer", "context"). Field to be used for truncation if the combined length exceeds the max sequence length. + pad_to_max_length: Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch. + index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset. + prompt_template: Prompt template to inject via an fstring. Formatted like Q: {input}\n\nA: {output} + """ + + def __init__( + self, + manifest_filepath: str, + tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec', + sample_rate: int, + int_values: bool = False, + augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None, + max_duration: Optional[int] = None, + min_duration: Optional[int] = None, + max_utts: int = 0, + trim: bool = False, + channel_selector: Optional[ChannelSelectorType] = None, + max_seq_length: int = 1024, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + add_sep: bool = False, + sep_id: Optional[int] = None, + max_num_samples: Optional[int] = None, + seed: int = 1234, + separate_prompt_and_response_with_newline: bool = False, + answer_only_loss: bool = True, + truncation_field: str = "answer", + pad_to_max_length: bool = False, # (@adithyare) allows for much faster training especially in PEFT settings. 
+ index_mapping_dir: str = None, + prompt_template: str = None, + virtual_tokens: int = 0, + tokens_to_generate: int = 0, + index_by_file_id: bool = False, + ): + self.input_key = 'input' + self.output_key = 'output' + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + self.min_seq_length = min_seq_length + self.max_num_samples = max_num_samples + self.seed = seed + self.separate_prompt_and_response_with_newline = separate_prompt_and_response_with_newline + self.answer_only_loss = answer_only_loss + self.truncation_field = truncation_field + self.pad_to_max_length = pad_to_max_length + self.index_mapping_dir = index_mapping_dir + self.prompt_template = prompt_template + self.virtual_tokens = virtual_tokens + self.tokens_to_generate = tokens_to_generate + self.tokens_to_generate = tokens_to_generate + self.add_bos = add_bos + self.add_eos = add_eos + self.add_sep = add_sep + + if add_bos and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0: + self.bos_id = tokenizer.bos_id + else: + self.bos_id = None + + if add_eos and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0: + self.eos_id = tokenizer.eos_id + else: + self.eos_id = None + + if hasattr(tokenizer, "pad_id"): + self.pad_id = tokenizer.pad_id + else: + self.pad_id = self.eos_id if self.eos_id is not None else 0 + + self.sep_id = sep_id if add_sep else None + + if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0: + self.pad_id = tokenizer.pad_id + else: + self.pad_id = 0 + + if self.prompt_template is not None: + # When providing things like newlines in the prompt template via the CLI, they are escaped. This line unescapes them. + self.prompt_template = self.prompt_template.encode('utf-8').decode('unicode_escape') + assert self.truncation_field in ["answer", "context"] + + if type(manifest_filepath) == str: + manifest_filepath = manifest_filepath.split(",") + + # If necessary, cache manifests and audio from object store + cache_datastore_manifests(manifest_filepaths=manifest_filepath, cache_audio=True) + + self.collection = collections.ALMAudioQA( + manifests_files=manifest_filepath, + min_duration=min_duration, + max_duration=max_duration, + max_number=max_utts, + index_by_file_id=index_by_file_id, + ) + + self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) + self.trim = trim + self.channel_selector = channel_selector + + def get_manifest_sample(self, sample_id): + return self.collection[sample_id] + + def __getitem__(self, index): + output = {"idx": index} + sample = self.collection[index] + offset = sample.offset + + if offset is None: + offset = 0 + + features = self.featurizer.process( + sample.audio_file, + offset=offset, + duration=sample.duration, + trim=self.trim, + orig_sr=sample.orig_sr, + channel_selector=self.channel_selector, + ) + f, fl = features, torch.tensor(features.shape[0]).long() + output["audio_signal"] = f + output["audio_length"] = fl + + text_data = self._process_example(context=sample.question, output=sample.answer) + + output.update(text_data) + + return output + + def __len__(self): + return len(self.collection) + + def _process_example(self, context: str, output: str): + """ + Create an example by concatenating text and answer. + Truncation is carried out when needed, but it is performed only on the prompt side. + BOS, EOS, and SEP, are added if specified. 
+ + function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py + """ + + if self.prompt_template is not None: + assert f'{{{self.input_key}}}' in self.prompt_template + assert f'{{{self.output_key}}}' in self.prompt_template + # Make sure that '{output}' always occurs at the end of the prompt template string + assert self.prompt_template.index(f'{{{self.output_key}}}') == len(self.prompt_template) - len( + f'{{{self.output_key}}}' + ) + # Get the context by replacing only the input + original_context = context + context = ( + self.prompt_template.replace(f'{{{self.input_key}}}', context) + .replace(f'{{{self.output_key}}}', '') + .strip(' ') + ) + # Replace the input and output placeholders with the actual input and output + text = self.prompt_template.replace(f'{{{self.input_key}}}', original_context).replace( + f'{{{self.output_key}}}', output + ) + + if self.separate_prompt_and_response_with_newline and self.prompt_template is None: + text = context + '\n' + output + elif not self.separate_prompt_and_response_with_newline and self.prompt_template is None: + text = context + ' ' + output + + if self.virtual_tokens: + # (@adithyare) we are going to insert "pad/eos" tokens in the beginning of the text and context + # these pad/eos tokens are placeholders for virtual tokens + pre_pad = [self.tokenizer.eos_id] * self.virtual_tokens + else: + pre_pad = [] + tokenized_text = pre_pad + self.tokenizer.text_to_ids(text) + context_ids = pre_pad + self.tokenizer.text_to_ids(context) + answer_ids = tokenized_text[len(context_ids) :] + + # for the long context cases, collate_fn includes self.tokens_to_generate for padding + total_ids = len(context_ids) + max(len(answer_ids), self.tokens_to_generate) + if self.add_bos: + total_ids += 1 + if self.add_sep: + total_ids += 1 + if self.add_eos: + total_ids += 1 + + # If the total number of token is greater than the max, we will try to truncate the answer + if total_ids > self.max_seq_length: + truncation_length = total_ids - self.max_seq_length + if self.truncation_field == "answer": + answer_ids = answer_ids[: -min(truncation_length, len(answer_ids))] + elif self.truncation_field == "context": + context_ids = context_ids[: -min(truncation_length, len(context_ids))] + + if len(context_ids) > self.max_seq_length: + context_ids = context_ids[: self.max_seq_length] + + assert len(context_ids) <= self.max_seq_length + input_ids = context_ids + + answer_start_idx = len(input_ids) + # Adds sep token between text/prompt and answer + if self.add_sep: + input_ids = input_ids + [self.sep_id] + answer_start_idx += 1 + + input_ids = input_ids + answer_ids + + if self.add_bos: + input_ids = [self.tokenizer.bos_id] + input_ids + answer_start_idx += 1 + if self.add_eos: + input_ids = input_ids + [self.tokenizer.eos_id] + + if len(input_ids) < self.min_seq_length or len(input_ids) > self.max_seq_length: + input_ids = input_ids[: self.max_seq_length] + + processed_example = { + 'input_ids': input_ids, + 'answer_start_idx': answer_start_idx, + 'context_ids': context_ids, + 'context_length': len(context_ids), + } + + return processed_example + + def _build_loss_mask(self, processed_example): + """ Pad input_ids in batch to max batch length while building loss mask """ + # function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py + input_ids = processed_example['input_ids'] + answer_start_idx = processed_example['answer_start_idx'] + if self.answer_only_loss: + loss_mask = [float(idx >= answer_start_idx) 
for idx in range(len(input_ids))] + else: + loss_mask = [1.0] * len(input_ids) + + return loss_mask + + def _collate_item(self, item, max_length, pad_id): + # function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py + item = maybe_cast_to_list(item) + # max_length = max([len(x) for x in item]) if item else 0 + # here [0] should be tokenizer.pad_id + item = [x + [pad_id] * (max_length - len(x)) for x in item] + return item + + def _collate_fn(self, batch): + sample_ids = [x["idx"] for x in batch] + sample_ids = torch.tensor(sample_ids, dtype=torch.int32) + + audio_signal = [x["audio_signal"] for x in batch] + audio_lengths = [x["audio_length"] for x in batch] + audio_signal, audio_lengths = _speech_collate_fn(audio_signal, audio_lengths) + + input_ids = [item['input_ids'][:-1] for item in batch] + labels = [item['input_ids'][1:] for item in batch] + contexts = [item['context_ids'] for item in batch] + context_lengths = torch.LongTensor([item['context_length'] for item in batch]) + loss_mask = [self._build_loss_mask(item)[1:] for item in batch] + + max_length = max([len(x) for x in input_ids]) + self.tokens_to_generate + # increase max length to nearest multiple of 4 or 8 + if self.pad_to_max_length: + max_length = self.max_seq_length + else: + max_length = min(self.max_seq_length, ceil_to_nearest(max_length, 8)) + assert max_length <= self.max_seq_length + + position_ids = [list(range(max_length)) for _ in batch] + position_ids = torch.LongTensor(position_ids) + input_length = torch.LongTensor([len(x) for x in input_ids]) + input_ids = torch.LongTensor( + self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id) + ) + labels = torch.LongTensor(self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id)) + loss_mask = torch.LongTensor(self._collate_item(loss_mask, max_length=max_length, pad_id=0)) + contexts = torch.LongTensor(self._collate_item(contexts, max_length=max_length, pad_id=self.tokenizer.eos_id)) + + batch = { + 'sample_ids': sample_ids, + 'audio_signal': audio_signal, + 'audio_signal_length': audio_lengths, + 'tokens': input_ids, + 'tokens_length': input_length, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'contexts': contexts, + 'context_lengths': context_lengths, + 'max_length': torch.LongTensor(max_length), + } + + return batch + + +class _TarredAudioToTextDataset(IterableDataset): + """ + A similar Dataset to the AudioToCharDataset/AudioToBPEDataset, but which loads tarred audio files. + + Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset/AudioToBPEDataset), + as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should + contain the information for one audio file, including at least the transcript and name of the audio + file within the tarball. + + Valid formats for the audio_tar_filepaths argument include: + (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or + (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...]. + + Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference. + This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements. + Supported opening braces - { <=> (, [, < and the special tag _OP_. 
+ Supported closing braces - } <=> ), ], > and the special tag _CL_. + For SLURM based tasks, we suggest the use of the special tags for ease of use. + + See the WebDataset documentation for more information about accepted data and input formats. + + If using multiple workers the number of shards should be divisible by world_size to ensure an + even split among workers. If it is not divisible, logging will give a warning but training will proceed. + In addition, if using mutiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering + is applied. We currently do not check for this, but your program may hang if the shards are uneven! + + Notice that a few arguments are different from the AudioToCharDataset; for example, shuffle (bool) has been + replaced by shuffle_n (int). + + Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest + after filtering. An incorrect manifest length may lead to some DataLoader issues down the line. + + Args: + audio_tar_filepaths: Either a list of audio tarball filepaths, or a + string (can be brace-expandable). + manifest_filepath (str): Path to the manifest. + parser (callable): A callable which is used to pre-process the text output. + sample_rate (int): Sample rate to resample loaded audio to + int_values (bool): If true, load samples as 32-bit integers. Defauts to False. + augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor + object used to augment loaded audio + shuffle_n (int): How many samples to look ahead and load to be shuffled. + See WebDataset documentation for more details. + Defaults to 0. + min_duration (float): Dataset parameter. + All training files which have a duration less than min_duration + are dropped. Note: Duration is read from the manifest JSON. + Defaults to 0.1. + max_duration (float): Dataset parameter. + All training files which have a duration more than max_duration + are dropped. Note: Duration is read from the manifest JSON. + Defaults to None. + blank_index (int): Blank character index, defaults to -1. + unk_index (int): Unknown character index, defaults to -1. + normalize (bool): Dataset parameter. + Whether to use automatic text cleaning. + It is highly recommended to manually clean text for best results. + Defaults to True. + trim (bool): Whether to use trim silence from beginning and end + of audio signal using librosa.effects.trim(). + Defaults to False. + bos_id (id): Dataset parameter. + Beginning of string symbol id used for seq2seq models. + Defaults to None. + eos_id (id): Dataset parameter. + End of string symbol id used for seq2seq models. + Defaults to None. + pad_id (id): Token used to pad when collating samples in batches. + If this is None, pads using 0s. + Defaults to None. + shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp. + - `scatter`: The default shard strategy applied by WebDataset, where each node gets + a unique set of shards, which are permanently pre-allocated and never changed at runtime. + - `replicate`: Optional shard strategy, where each node gets all of the set of shards + available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. + The benefit of replication is that it allows each node to sample data points from the entire + dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. + + .. 
warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. + shard_manifests (bool): Whether or not to try / shard manifests. Defaults to False. + global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. + world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. + return_sample_id (bool): whether to return the sample_id as a part of each sample + --------- NLP SPECIFIC ARGS ------------- + max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. + min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. + add_bos (bool): Whether to add a beginning of sentence token to each data example + add_eos (bool): Whether to add an end of sentence token to each data example + add_sep (bool): Whether to add a separation token to each data example (goes between prompt and answer) + tokens_to_generate (int): (inference only) Number of tokens to generate during inference + seed: Random seed for data shuffling. + max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded. + seed: int = 1234, + input_key: Key to use for the context in your JSONL file + output_key: Key to use for the label in your JSONL file + separate_prompt_and_response_with_newline: Adds a newline between prompt and response. + answer_only_loss: If True, will compute the loss only on the answer part of the input. If False, will compute the loss on the entire input. + truncation_field: Field to use for truncation. (Options: "answer", "context"). Field to be used for truncation if the combined length exceeds the max sequence length. + pad_to_max_length: Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch. + index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset. + prompt_template: Prompt template to inject via an fstring. 
Formatted like Q: {input}\n\nA: {output} + """ + + def __init__( + self, + audio_tar_filepaths: Union[str, List[str]], + manifest_filepath: str, + parser: Callable, + sample_rate: int, + int_values: bool = False, + augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None, + shuffle_n: int = 0, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + trim: bool = False, + bos_id: Optional[int] = None, + eos_id: Optional[int] = None, + pad_id: int = 0, + shard_strategy: str = "scatter", + shard_manifests: bool = False, + global_rank: int = 0, + world_size: int = 0, + return_sample_id: bool = False, + max_seq_length: int = 1024, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + add_sep: bool = False, + sep_id: int = None, + max_num_samples: int = None, + seed: int = 1234, + separate_prompt_and_response_with_newline: bool = False, + answer_only_loss: bool = True, + truncation_field: str = "answer", + pad_to_max_length: bool = False, # (@adithyare) allows for much faster training especially in PEFT settings. + index_mapping_dir: str = None, + prompt_template: str = None, + virtual_tokens: int = 0, + tokens_to_generate: int = 0, + ): + self.max_seq_length = max_seq_length + self.min_seq_length = min_seq_length + self.add_bos = add_bos + self.add_eos = add_eos + self.add_sep = add_sep + self.sep_id = sep_id + self.max_num_samples = max_num_samples + self.seed = seed + self.separate_prompt_and_response_with_newline = separate_prompt_and_response_with_newline + self.answer_only_loss = answer_only_loss + self.truncation_field = truncation_field + self.pad_to_max_length = pad_to_max_length + self.index_mapping_dir = index_mapping_dir + self.prompt_template = prompt_template + self.virtual_tokens = virtual_tokens + self.tokens_to_generate = tokens_to_generate + self.tokens_to_generate = tokens_to_generate + if self.prompt_template is not None: + # When providing things like newlines in the prompt template via the CLI, they are escaped. This line unescapes them. 
+ self.prompt_template = self.prompt_template.encode('utf-8').decode('unicode_escape') + assert self.truncation_field in ["answer", "context"] + + self.shard_manifests = shard_manifests + + # Shard manifests if necessary and possible and then expand the paths + manifest_filepath = shard_manifests_if_needed( + shard_manifests=shard_manifests, + shard_strategy=shard_strategy, + manifest_filepaths=manifest_filepath, + world_size=world_size, + global_rank=global_rank, + ) + + # If necessary, cache manifests from object store + cache_datastore_manifests(manifest_filepaths=manifest_filepath) + + self.manifest_processor = ASRManifestProcessor( + manifest_filepath=manifest_filepath, + parser=parser, + max_duration=max_duration, + min_duration=min_duration, + max_utts=0, + bos_id=bos_id, + eos_id=eos_id, + pad_id=pad_id, + index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID + ) + + self.len = self._compute_len() + + self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) + self.trim = trim + self.eos_id = eos_id + self.bos_id = bos_id + self.pad_id = pad_id + self.return_sample_id = return_sample_id + + audio_tar_filepaths = expand_sharded_filepaths( + sharded_filepaths=audio_tar_filepaths, + shard_strategy=shard_strategy, + world_size=world_size, + global_rank=global_rank, + ) + + # Put together WebDataset + self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None) + + if shuffle_n > 0: + self._dataset = self._dataset.shuffle(shuffle_n) + else: + logging.info("WebDataset will not shuffle files within the tar files.") + + self._dataset = ( + self._dataset.rename(audio='wav;ogg;flac', key='__key__') + .to_tuple('audio', 'key') + .pipe(self._filter) + .pipe(self._loop_offsets) + .map(f=self._build_sample) + ) + + def _filter(self, iterator): + """This function is used to remove samples that have been filtered out by ASRAudioText already. + Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample + that was filtered out (e.g. for duration). + Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard, + which may make your code hang as one process will finish before the other. + """ + + class TarredAudioFilter: + def __init__(self, collection): + self.iterator = iterator + self.collection = collection + + def __iter__(self): + return self + + def __next__(self): + while True: + audio_bytes, audio_filename = next(self.iterator) + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + if file_id in self.collection.mapping: + return audio_bytes, audio_filename + + return TarredAudioFilter(self.manifest_processor.collection) + + def _loop_offsets(self, iterator): + """This function is used to iterate through utterances with different offsets for each file. 
+ """ + + class TarredAudioLoopOffsets: + def __init__(self, collection): + self.iterator = iterator + self.collection = collection + self.current_fn = None + self.current_bytes = None + self.offset_id = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.current_fn is None: + self.current_bytes, self.current_fn = next(self.iterator) + self.offset_id = 0 + else: + offset_list = self.collection.mapping[self.current_fn] + if len(offset_list) == self.offset_id + 1: + self.current_bytes, self.current_fn = next(self.iterator) + self.offset_id = 0 + else: + self.offset_id += 1 + + return self.current_bytes, self.current_fn, self.offset_id + + return TarredAudioLoopOffsets(self.manifest_processor.collection) + + def _collate_fn(self, batch): + return _speech_collate_fn(batch, self.pad_id) + + def _build_sample(self, tup): + """Builds the training sample by combining the data from the WebDataset with the manifest info. + """ + audio_bytes, audio_filename, offset_id = tup + + # Grab manifest entry from self.manifest_preprocessor.collection + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + manifest_idx = self.manifest_processor.collection.mapping[file_id][offset_id] + manifest_entry = self.manifest_processor.collection[manifest_idx] + + offset = manifest_entry.offset + if offset is None: + offset = 0 + + # Convert audio bytes to IO stream for processing (for SoundFile to read) + audio_filestream = io.BytesIO(audio_bytes) + features = self.featurizer.process( + audio_filestream, + offset=offset, + duration=manifest_entry.duration, + trim=self.trim, + orig_sr=manifest_entry.orig_sr, + ) + audio_filestream.close() + + # Audio features + f, fl = features, torch.tensor(features.shape[0]).long() + + # Text features + t, tl = manifest_entry.text_tokens, len(manifest_entry.text_tokens) + + self.manifest_processor.process_text_by_sample(sample=manifest_entry) + + if self.bos_id is not None: + t = [self.bos_id] + t + tl += 1 + if self.eos_id is not None: + t = t + [self.eos_id] + tl += 1 + + if self.return_sample_id: + return f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), manifest_idx + else: + return f, fl, torch.tensor(t).long(), torch.tensor(tl).long() + + def get_manifest_sample(self, sample_id): + return self.manifest_processor.collection[sample_id] + + def __iter__(self): + return self._dataset.__iter__() + + def _compute_len(self): + if self.shard_manifests and torch.distributed.is_available() and torch.distributed.is_initialized(): + my_len = torch.tensor(len(self.manifest_processor.collection), dtype=torch.int32).cuda() + torch.distributed.all_reduce(my_len) + my_len = my_len.int() + logging.info(f'Sharded manifests: Total length: {my_len}') + else: + my_len = len(self.manifest_processor.collection) + + return my_len + + def __len__(self): + return self.len + + +class TarredAudioToBPEDataset(_TarredAudioToTextDataset): + """ + A similar Dataset to the AudioToBPEDataset, but which loads tarred audio files. + + Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToBPEDataset), + as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should + contain the information for one audio file, including at least the transcript and name of the audio + file within the tarball. + + Valid formats for the audio_tar_filepaths argument include: + (1) a single string that can be brace-expanded, e.g. 
'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or + (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...]. + + See the WebDataset documentation for more information about accepted data and input formats. + + If using multiple workers the number of shards should be divisible by world_size to ensure an + even split among workers. If it is not divisible, logging will give a warning but training will proceed. + In addition, if using mutiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering + is applied. We currently do not check for this, but your program may hang if the shards are uneven! + + Notice that a few arguments are different from the AudioToBPEDataset; for example, shuffle (bool) has been + replaced by shuffle_n (int). + + Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest + after filtering. An incorrect manifest length may lead to some DataLoader issues down the line. + + Args: + audio_tar_filepaths: Either a list of audio tarball filepaths, or a + string (can be brace-expandable). + manifest_filepath (str): Path to the manifest. + tokenizer (TokenizerSpec): Either a Word Piece Encoding tokenizer (BERT), + or a Sentence Piece Encoding tokenizer (BPE). The CTC blank + symbol is automatically added later for models using ctc. + sample_rate (int): Sample rate to resample loaded audio to + int_values (bool): If true, load samples as 32-bit integers. Defauts to False. + augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor + object used to augment loaded audio + shuffle_n (int): How many samples to look ahead and load to be shuffled. + See WebDataset documentation for more details. + Defaults to 0. + min_duration (float): Dataset parameter. + All training files which have a duration less than min_duration + are dropped. Note: Duration is read from the manifest JSON. + Defaults to 0.1. + max_duration (float): Dataset parameter. + All training files which have a duration more than max_duration + are dropped. Note: Duration is read from the manifest JSON. + Defaults to None. + trim (bool): Whether to use trim silence from beginning and end + of audio signal using librosa.effects.trim(). + Defaults to False. + use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS] + tokens to beginning and ending of speech respectively. + pad_id (id): Token used to pad when collating samples in batches. + If this is None, pads using 0s. + Defaults to None. + shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp. + + - `scatter`: The default shard strategy applied by WebDataset, where each node gets + a unique set of shards, which are permanently pre-allocated and never changed at runtime. + - `replicate`: Optional shard strategy, where each node gets all of the set of shards + available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. + The benefit of replication is that it allows each node to sample data points from the entire + dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. + + .. warning:: + + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. 
Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. + + global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. + world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. + return_sample_id (bool): whether to return the sample_id as a part of each sample + """ + + def __init__( + self, + audio_tar_filepaths: Union[str, List[str]], + manifest_filepath: str, + tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec', + sample_rate: int, + int_values: bool = False, + augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None, + shuffle_n: int = 0, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + trim: bool = False, + use_start_end_token: bool = True, + shard_strategy: str = "scatter", + shard_manifests: bool = False, + global_rank: int = 0, + world_size: int = 0, + return_sample_id: bool = False, + ): + if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0: + bos_id = tokenizer.bos_id + else: + bos_id = None + + if use_start_end_token and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0: + eos_id = tokenizer.eos_id + else: + eos_id = None + + if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0: + pad_id = tokenizer.pad_id + else: + pad_id = 0 + + class TokenizerWrapper: + def __init__(self, tokenizer): + if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer): + self.is_aggregate = True + else: + self.is_aggregate = False + self._tokenizer = tokenizer + + def __call__(self, *args): + if isinstance(args[0], List) and self.is_aggregate: + t = [] + for span in args[0]: + t.extend(self._tokenizer.text_to_ids(span['str'], span['lang'])) + return t + + t = self._tokenizer.text_to_ids(*args) + return t + + super().__init__( + audio_tar_filepaths=audio_tar_filepaths, + manifest_filepath=manifest_filepath, + parser=TokenizerWrapper(tokenizer), + sample_rate=sample_rate, + int_values=int_values, + augmentor=augmentor, + shuffle_n=shuffle_n, + min_duration=min_duration, + max_duration=max_duration, + trim=trim, + bos_id=bos_id, + eos_id=eos_id, + pad_id=pad_id, + shard_strategy=shard_strategy, + shard_manifests=shard_manifests, + global_rank=global_rank, + world_size=world_size, + return_sample_id=return_sample_id, + ) + + +class BucketingDataset(IterableDataset): + """ + A Dataset which wraps another IterableDataset and adopts it for bucketing + Args: + dataset (IterableDataset): The IterableDataset to get wrapped + bucketing_batch_size (int): Number of samples to build a batch + """ + + def __init__( + self, dataset: IterableDataset, bucketing_batch_size: int, + ): + self.wrapped_dataset = dataset + self.bucketing_batch_size = bucketing_batch_size + super().__init__() + + def _collate_fn(self, batch): + return _speech_collate_fn(batch[0], self.wrapped_dataset.pad_id) + + def __iter__(self): + return BucketingIterator( + wrapped_ds=self.wrapped_dataset._dataset, bucketing_batch_size=self.bucketing_batch_size + ).__iter__() + + def __len__(self): + return int(math.ceil(len(self.wrapped_dataset) / float(self.bucketing_batch_size))) + + +class BucketingIterator: + def __init__(self, wrapped_ds, bucketing_batch_size): + self.wrapped_ds = wrapped_ds + self.wrapped_iter = None + self.bucketing_batch_size = bucketing_batch_size + + def 
__iter__(self): + self.wrapped_iter = iter(self.wrapped_ds) + return self + + def __next__(self): + batches = [] + for idx in range(self.bucketing_batch_size): + try: + sample = next(self.wrapped_iter) + except StopIteration: + break + batches.append(sample) + if len(batches) == 0: + raise StopIteration + return batches + + +class RandomizedChainDataset(ChainDataset): + def __init__(self, datasets: Iterable[Dataset], rnd_seed=0) -> None: + super(RandomizedChainDataset, self).__init__(list(datasets)) + self.rnd_gen = np.random.RandomState(rnd_seed) + + def __iter__(self): + shuffled_order = self.rnd_gen.permutation(len(self.datasets)) + for dataset_idx in shuffled_order: + d = self.datasets[dataset_idx] + assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset" + for idx, x in enumerate(d): + yield x + # in case d is an infinite dataset, we want to break the loop + # so that the other datasets get a chance to yield too + if idx >= len(d) - 1: + break diff --git a/nemo/collections/multimodal/models/__init__.py b/nemo/collections/multimodal/models/__init__.py index 932d3e7e57b1..bef1dfbef83f 100644 --- a/nemo/collections/multimodal/models/__init__.py +++ b/nemo/collections/multimodal/models/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.multimodal.models.speechllm_models import * \ No newline at end of file +from nemo.collections.multimodal.models.speechllm_models import * diff --git a/nemo/collections/multimodal/models/speechllm_models.py b/nemo/collections/multimodal/models/speechllm_models.py index 30d24b24930c..d10d8e98bd94 100644 --- a/nemo/collections/multimodal/models/speechllm_models.py +++ b/nemo/collections/multimodal/models/speechllm_models.py @@ -14,26 +14,31 @@ import itertools import os -from typing import Optional, Union, Dict from functools import partial +from typing import Dict, Optional, Union import torch from omegaconf.dictconfig import DictConfig -from nemo.utils import logging, model_utils -from nemo.collections.asr.data.audio_to_text_dali import ( - DALIOutputs, -) +from omegaconf.omegaconf import OmegaConf, open_dict from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.asr.data import audio_to_text_dataset -from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset -from nemo.core.classes.mixins import AccessMixin -from nemo.collections.nlp.modules.common.megatron.utils import build_position_ids +from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset, DALIOutputs +from nemo.collections.asr.models import ASRModel +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.multimodal.data.audio_text_qa_dataset import AudioQuestionAnswerDataset +from nemo.collections.multimodal.modules.speechllm_perception import AudioPerceptionModel +from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( + get_datasets_weights_and_num_samples, +) +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.collections.nlp.models.language_modeling.megatron_gpt_peft_models import MegatronGPTLoRAModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_prompt_learning_model import ( MegatronGPTPromptLearningModel, ) from nemo.collections.nlp.modules.common.megatron.utils import ( average_losses_across_data_parallel_group, + build_position_ids, 
get_iterator_k_split, ) from nemo.collections.nlp.modules.common.text_generation_utils import ( @@ -41,25 +46,15 @@ get_default_sampling_params, megatron_gpt_generate, ) -from nemo.collections.nlp.modules.common.transformer.text_generation import ( - LengthParam, - SamplingParam, -) +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector, PEFTSaveRestoreConnector from nemo.collections.nlp.parts.utils_funcs import get_last_rank -from nemo.core.neural_types import ( - AcousticEncodedRepresentation, - AudioSignal, - LengthsType, - NeuralType, - SpectrogramType, -) -from nemo.utils import AppState, logging +from nemo.core.classes.mixins import AccessMixin, adapter_mixins +from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType +from nemo.utils import AppState, logging, model_utils try: - from apex.transformer.pipeline_parallel.utils import ( - get_micro_batch_size, - get_num_microbatches, - ) + from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches HAVE_APEX = True @@ -77,104 +72,104 @@ HAVE_MEGATRON_CORE = False -__all__ = ["ModularizedSpeechGPTModel"] +__all__ = ["ModularizedAudioGPTModel"] -class ModularizedSpeechGPTModel(MegatronGPTPromptLearningModel): +class ModularizedAudioGPTModel(MegatronGPTLoRAModel): """Modularized speech GPT model.""" def __init__(self, cfg: DictConfig, trainer: Trainer): self.cfg = cfg super().__init__(cfg, trainer) - self.init_perception_model(cfg, trainer) - - def init_perception_model(self, cfg: DictConfig, trainer: Trainer): - # Convert to Hydra 1.0 compatible DictConfig - cfg = model_utils.convert_model_config_to_dict_config(cfg) - cfg = model_utils.maybe_update_config_version(cfg) - - if not isinstance(cfg, DictConfig): - raise ValueError("cfg must be an OmegaConf DictConfig") - - self.perception = ModularizedSpeechGPTModel.from_config_dict( - self.cfg.perception - ) - # TODO(zhehuai): load pretrained perception model weights - - # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable - # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 - self.world_size = 1 - if trainer is not None: - self.world_size = trainer.world_size - fixed_prompt_prefix_str = cfg.get("fixed_prompt_prefix", None) - if fixed_prompt_prefix_str is not None: - self.fixed_prompt_prefix = torch.Tensor(self.tokenizer.text_to_ids( - fixed_prompt_prefix_str - )).int().cuda() - else: - self.fixed_prompt_prefix = None - - # follow MegatronGPTPromptLearningModel for GPT model init - def init_model(self, cfg: DictConfig, trainer: Trainer): - super().init_model(cfg, trainer) - # gpt code handle the setup of the tokenizer - # disable text prompt tuning specifics - self.existing_tasks = None - self.new_tasks = None - self.virtual_prompt_style = None - self.word_embeddings = ( - self.frozen_model.model.language_model.embedding.word_embeddings - ) - self.pseudo_tokens = None - self.pseudo_token_ids = None - self.pseudo_token_ids_start = None - self.virtual_prompt_source = None - self.prompt_encoder = None - # self.frozen_model is frozen by setup_optimizer_param_groups - - def state_dict(self): + self.perception = AudioPerceptionModel(cfg=cfg.perception) + self.setup_optimizer_param_groups() + self.configure_optimizers() + self.summarize() + + def parameters(self): + # override the same method in MegatronGPT model 
to include parameters ouside of LM + all_names = [] + all_params = [] + for name, param in self.named_parameters(recurse=True): + all_names.append(name) + all_params.append(param) + + if isinstance(self.model, list): + for module in self.model: + for name, param in module.named_parameters(recurse=True): + all_names.append(name) + all_params.append(param) + + return itertools.chain(all_params) + + def setup_optimizer_param_groups(self): """ - TODO(zhehuai): Custom state dict. + ModelPT override. Optimizer will get self._optimizer_param_groups. + Makes two optimizer param groups, one for the frozen model params + and one for the prompt-table/prompt-encoder params. The learning + rate for the frozen model's params will always be zero effectively + freezing the model's params but still allowing for the needed gradients + to be passed around in pipeline parallel models. The prompt-encoder + and/or prompt table will use the learning rate set by the user. """ - state_dict_ = {} - - if self.first_stage_of_pipeline(): - pass - - return state_dict_ - - def load_task_templates(self, task_templates): - # TODO(zhehuai): support task template to support complexer SFT format - self.task_templates = {} - self.task_id_num_to_name = {} - self.max_virtual_tokens = 0 - - def get_text_batch_from_audio(self, audio_batch): - _, _, transcript, transcript_len = audio_batch - # TODO(zhehuai) Add BOS/EOS if desired, adds EOS by default - labels = transcript[:, 1:].contiguous() - input_ids = transcript[:, :-1].contiguous() - input_length = transcript_len - 1 - - b = labels.shape[0] - max_len = labels.shape[1] - # Loss mask where answer tokens are 1.0 and all other tokens are 0.0 - loss_mask = torch.arange(max_len).expand(b, max_len).cuda() < input_length.unsqueeze(1) - loss_mask = loss_mask.float() - return input_ids, input_length, labels, loss_mask + self.unfreeze() + known_groups = [] + if self.cfg.get('freeze_llm', True): + for param in self.model.parameters(): + param.requires_grad = False + known_groups.append('model.') + # TODO(heh): double check this part works properly + if self.cfg.get('freeze_matcher', False): + self.perception.matcher.freeze() + known_groups.append('matcher.') + if self.cfg.get('freeze_audio_encoder', False): + self.perception.encoder.freeze() + known_groups.append('audio_encoder.') + + opt_params = [] + for _, module in self.named_modules(): + if isinstance(module, adapter_mixins.AdapterModuleMixin) and module.is_adapter_available(): + module.set_enabled_adapters(enabled=True) + module.unfreeze_enabled_adapters() # selectively unfreeze the adapter modules. 
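+                # gather the parameters of adapter-hosting modules so the newly enabled adapter
+                # weights land in the default optimizer param group assembled below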
+ opt_params += [p for p in module.parameters()] + + param_groups = [] + if "optim_param_groups" in self.cfg: + param_groups_cfg = self.cfg.optim_param_groups + for group, group_cfg in param_groups_cfg.items(): + module = getattr(self, group, None) + if module is None: + raise ValueError(f"{group} not found in model.") + elif hasattr(module, "parameters"): + known_groups.append(f"{group}.") + new_group = {"params": module.parameters()} + for k, v in group_cfg.items(): + new_group[k] = v + param_groups.append(new_group) + else: + raise ValueError(f"{group} does not have parameters.") + + for n, p in self.named_parameters(): + is_unknown = True + for group in known_groups: + if n.startswith(group): + is_unknown = False + if is_unknown: + opt_params.append(p) + + param_groups = [{"params": opt_params}] + param_groups + + self._optimizer_param_groups = param_groups + logging.info(f"Optimizer groups set:\n{self.summarize()}") def prepare_llm_input(self, audio_batch): def _concat_embs(embs1, emb1_lens, embs2, emb2_lens): concat_emb = [] concat_len = [] - for emb1, emb1_len, emb2, emb2_len in zip( - embs1, emb1_lens, embs2, emb2_lens - ): + for emb1, emb1_len, emb2, emb2_len in zip(embs1, emb1_lens, embs2, emb2_lens): new_len = emb1_len + emb2_len new_emb = torch.concat([emb1[:emb1_len], emb2[:emb2_len]], axis=0) - padded_new_emb = torch.zeros( - emb1.shape[0] + emb2.shape[0], emb1.shape[-1] - ) + padded_new_emb = torch.zeros(emb1.shape[0] + emb2.shape[0], emb1.shape[-1], device=emb1.device) padded_new_emb[:new_len, ...] = new_emb concat_emb.append(padded_new_emb) concat_len.append(new_len) @@ -182,73 +177,46 @@ def _concat_embs(embs1, emb1_lens, embs2, emb2_lens): concat_len = torch.stack(concat_len, dim=0) return concat_emb, concat_len - def _shift_labels_by_emb_len( - labels, label_lens, emb_lens, max_len, pad_token=0 - ): + def _shift_labels_by_emb_len(labels, label_lens, emb_lens, max_len, pad_token=0): shifted_labels = [] for label, label_len, emb_len in zip(labels, label_lens, emb_lens): - shifted_label = torch.full([max_len], pad_token) + shifted_label = torch.full([max_len], pad_token, device=label.device) shifted_label[emb_len : emb_len + label_len] = label[:label_len] shifted_labels.append(shifted_label) shifted_labels = torch.stack(shifted_labels, dim=0) return shifted_labels - signal, signal_len, _, _ = audio_batch + input_signal = audio_batch['audio_signal'] + input_signal_length = audio_batch['audio_signal_length'] - # forward() only performs encoder forward - if isinstance(audio_batch, DALIOutputs) and audio_batch.has_processed_signal: - ( - input_signal, - input_signal_length, - processed_signal, - processed_signal_length, - ) = (None, None, signal, signal_len) - else: - ( - input_signal, - input_signal_length, - processed_signal, - processed_signal_length, - ) = (signal, signal_len, None, None) - - input_ids, input_length, labels, loss_mask = self.get_text_batch_from_audio( - audio_batch + input_ids, input_length, labels, loss_mask = ( + audio_batch['tokens'], + audio_batch['tokens_length'], + audio_batch['labels'], + audio_batch['loss_mask'], ) - if not self.frozen_model.model.pre_process: - raise ValueError("Model does not have pre_process method defined.") - # [b, t, c] encoded, encoded_len = self.perception( input_signal=input_signal, input_signal_length=input_signal_length, - processed_signal=processed_signal, - processed_signal_length=processed_signal_length, + processed_signal=None, + processed_signal_length=None, ) - if self.fixed_prompt_prefix is not None: - 
fixed_prompt_prefix = self.fixed_prompt_prefix.expand(encoded.shape[0], -1) - prompt_prefix = self.word_embeddings(fixed_prompt_prefix) - encoded = torch.cat([prompt_prefix, encoded], dim=1) - encoded_len += fixed_prompt_prefix.shape[1] # [b, t, c] - input_embeds = self.word_embeddings(input_ids) - encoder_input, encoder_length = _concat_embs( - encoded, encoded_len, input_embeds, input_length - ) - labels = _shift_labels_by_emb_len( - labels, input_length, encoded_len, encoder_input.shape[1], pad_token=0 - ) + lm_embedding = self.model.language_model.embedding + input_embeds = lm_embedding.word_embeddings(input_ids) + encoder_input, encoder_length = _concat_embs(encoded, encoded_len, input_embeds, input_length) + labels = _shift_labels_by_emb_len(labels, input_length, encoded_len, encoder_input.shape[1], pad_token=0) # Loss mask where answer tokens are 1.0 and all other tokens are 0.0 - loss_mask = _shift_labels_by_emb_len( - loss_mask, input_length, encoded_len, encoder_input.shape[1], pad_token=0 - ) + loss_mask = _shift_labels_by_emb_len(loss_mask, input_length, encoded_len, encoder_input.shape[1], pad_token=0) b = encoder_input.shape[0] max_len = encoder_input.shape[1] # Using causal attention mask for whole input # TODO(zhehuai): use prefixlm instead for the audio embeddings - attention_mask = torch.tril(torch.ones((b, max_len, max_len))).view( + attention_mask = torch.tril(torch.ones((b, max_len, max_len), device=encoder_input.device)).view( b, 1, max_len, max_len ) # Convert attention mask from float to bool @@ -256,329 +224,250 @@ def _shift_labels_by_emb_len( position_ids = build_position_ids(encoder_input[:, :, 0]) # Add position embeddings - if hasattr( - self.frozen_model.model.language_model.embedding, "position_embeddings" - ): - position_embeddings = ( - self.frozen_model.model.language_model.embedding.position_embeddings( - position_ids - ) - ) + if hasattr(lm_embedding, "position_embeddings"): + position_embeddings = lm_embedding.position_embeddings(position_ids) encoder_input = encoder_input + position_embeddings else: encoder_input = encoder_input encoder_input = encoder_input.transpose(0, 1).contiguous() if self.cfg.get("sequence_parallel", False): - encoder_input = ( - tensor_parallel.mappings.scatter_to_sequence_parallel_region( - encoder_input - ) - ) + encoder_input = tensor_parallel.mappings.scatter_to_sequence_parallel_region(encoder_input) return encoder_input, attention_mask, labels, loss_mask, encoder_length def forward( - self, - audio_batch, - inference=True, - set_inference_key_value_memory=False, - inference_max_sequence_len=None, + self, audio_batch, checkpoint_activations_all_layers, ): """Forward pass of the model. - We first prepend a fixed text instruction that briefly describes the - task to the audio embeddings. Then we prepend audio embeddings to - the label text tokens as the LLM input. - TODO(zhehuai): read text instruction from the SFT dataset, set loss_mask - accordingly, following pad_batch_and_build_loss_mask. + We prepend audio embeddings to the instruction and label text tokens + as the LLM input. 
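+        For a batch item with t_a audio frames and t_x text tokens, prepare_llm_input builds
+            encoder_input: [audio_emb(t_a) | text_emb(t_x)]   (transposed to [t, b, h] at the end)
+            labels:        [pad(t_a)       | text ids(t_x)]
+            loss_mask:     [0 ... 0 (t_a)  | 1.0 on answer tokens]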
""" - - # concat the text embeddings and the audio embeddings together to form the input embeddings - - encoder_input, attention_mask, labels, loss_mask, _ = self.prepare_llm_input( - audio_batch - ) - output = self.frozen_model.model( + encoder_input, attention_mask, labels, loss_mask, _ = self.prepare_llm_input(audio_batch) + output = self.model( input_ids=None, position_ids=None, encoder_input=encoder_input, attention_mask=attention_mask, labels=labels, - set_inference_key_value_memory=set_inference_key_value_memory, - inference_max_sequence_len=inference_max_sequence_len, + checkpoint_activations_all_layers=checkpoint_activations_all_layers, ) return output, loss_mask - def get_forward_output_and_loss_func(self): - def fwd_output_and_loss_func(dataloader_iter, model): + def get_forward_output_and_loss_func(self, validation_step=False): + def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): batch = next(dataloader_iter) - batch = [x.cuda(non_blocking=True) for x in batch] - output_tensor, loss_mask = model(batch, inference=False) - - if isinstance(output_tensor, tuple): - output_tensor, _ = output_tensor + batch = {key: val.cuda(non_blocking=True) for key, val in batch.items()} + output_tensor, loss_mask = self.forward( + batch, checkpoint_activations_all_layers=checkpoint_activations_all_layers + ) + output_tensor = output_tensor[0] # get loss only, ingore logits def loss_func(output_tensor): - loss = self.frozen_model.loss_func(loss_mask, output_tensor) - reduced_loss = average_losses_across_data_parallel_group([loss]) - return loss, {"avg": reduced_loss} + # Loss for a micro-batch (ub) + loss_for_ub = self.loss_func(loss_mask, output_tensor) + if validation_step and not self.cfg.data.get('validation_drop_last', True): + num_valid_tokens_in_ub = batch['loss_mask'].sum() + if loss_for_ub.isnan(): + assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input' + loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub) + else: + loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub + + loss_sum_and_ub_size_all_gpu = torch.cat( + [ + loss_sum_for_ub.clone().detach().view(1), + torch.tensor([num_valid_tokens_in_ub]).cuda().clone().detach(), + ] + ) + # Could potentially reduce num_valid_samples_in_microbatch and use that to aggregate instead of len(self._validation_ds) + torch.distributed.all_reduce( + loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group() + ) + return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} + else: + reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) + return loss_for_ub, {'avg': reduced_loss} return output_tensor, loss_func return fwd_output_and_loss_func - def training_step(self, batch, batch_nb): - # Reset access registry - if AccessMixin.is_access_enabled(): - AccessMixin.reset_registry(self) - - signal, signal_len, transcript, transcript_len = batch - loss_mean = self.fwd_bwd_step( - itertools.chain([batch]), None, forward_only=False - ) - self.allreduce_gradients() - - ## logging - # we can only log on one rank if it is rank zero so we broadcast from last rank - # we can avoid this broadcast by updating the PTL log function to accept specific ranks - torch.distributed.broadcast(loss_mean, get_last_rank()) - - # TODO(zhehuai): add loss and step logging - return loss_mean - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - # TODO(zhehuai) support infernece - pass - - def validation_step(self, batch, batch_idx, 
dataloader_idx=0): - loss_mean = self.fwd_bwd_step(itertools.chain([batch]), None, forward_only=True) - if loss_mean.item == 0.0: - loss_mean = [] - return {"loss": loss_mean} - - # dataset configuration - def _setup_dataloader_from_config(self, config: Optional[Dict], for_train=True): - dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( - config=config, - local_rank=self.local_rank, - global_rank=self.global_rank, - world_size=self.world_size, - tokenizer=self.tokenizer, - preprocessor_cfg=self.cfg.get("preprocessor", None), - ) - - if dataset is None: - return None - - if isinstance(dataset, AudioToBPEDALIDataset): - # DALI Dataset implements dataloader interface - return dataset + def _build_dataset(self, data_cfg, is_train=True): + datasets = [] - shuffle = config["shuffle"] - if isinstance(dataset, torch.utils.data.IterableDataset): - shuffle = False - - # TODO(zhehuai): test distributed dataloader and parallel_state - # Make distributed dataloader following build_virtual_prompt_dataset - rank = parallel_state.get_data_parallel_rank() - data_parallel_size = parallel_state.get_data_parallel_world_size() - sampler = torch.utils.data.distributed.DistributedSampler( - dataset, num_replicas=data_parallel_size, rank=rank, shuffle=shuffle, seed=self.cfg.seed - ) + if isinstance(data_cfg.file_names, str): + file_names = data_cfg.file_names.split(',') + else: + file_names = data_cfg.file_names + + if is_train and not data_cfg.get('is_tarred', False): + # Construct the data prefix list for `get_datasets_weights_and_num_samples()` + # that is of the format [weight1,file_name1,weight2,file_name2,...] + concat_sampling_probabilities = data_cfg.get('concat_sampling_probabilities', None) + if concat_sampling_probabilities is None: + concat_sampling_probabilities = [1.0 / len(file_names)] * len(file_names) + elif len(data_cfg.get('concat_sampling_probabilities', None)) != len(file_names): + raise ValueError( + ( + f"concat_sampling_probabilities must be of the same size as file_names.", + f"Provided size {len(data_cfg.concat_sampling_probabilities)}, number of datasets {len(file_names)}", + ) + ) - batch_size=config["batch_size"] - assert batch_size % data_parallel_size == 0, "Global batch size must be evenly divisible by data parallel size" + data_prefix = [] + for weight, prefix in zip(concat_sampling_probabilities, file_names): + data_prefix.append(weight) + data_prefix.append(prefix) - if for_train: - if self.cfg.get("sequence_parallel", False): - collate_fn = partial( - dataset.collate_fn, tp_workers=parallel_state.get_tensor_model_parallel_world_size() + if self.trainer.max_steps is None or self.trainer.max_steps <= 0: + raise ValueError( + f'Trainer max_steps must be set to a positive integer. Found {self.trainer.max_steps}' ) - else: - collate_fn = partial(dataset.collate_fn, tp_workers=0) + num_train_samples = [self.trainer.max_steps * data_cfg.global_batch_size] + _, _, num_train_samples_per_dataset = get_datasets_weights_and_num_samples(data_prefix, num_train_samples) + num_train_samples_after_blend = sum([x[0] for x in num_train_samples_per_dataset]) else: - collate_fn = dataset.inference_collate_fn - assert config.get("num_workers", 0) > 0, "(@adithyare and @eharper) We need this to make spawn=True to work." 
- - return torch.utils.data.DataLoader( - dataset, - collate_fn=collate_fn, - sampler=sampler, - batch_size=batch_size // data_parallel_size, - shuffle=shuffle, - drop_last=config.get("drop_last", False), - num_workers=config.get("num_workers", 0), - pin_memory=config.get("pin_memory", False), - persistent_workers=True - ) + num_train_samples_per_dataset = [[None]] * len(data_cfg.file_names) - def _setup_transcribe_dataloader( - self, config: Dict - ) -> "torch.utils.data.DataLoader": - """ - Setup function for a temporary data loader which wraps the provided audio file. - - Args: - config: A python dictionary which contains the following keys: - paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ - Recommended length per file is between 5 and 25 seconds. - batch_size: (int) batch size to use during inference. \ - Bigger will result in better throughput performance but would use more memory. - temp_dir: (str) A temporary directory where the audio manifest is temporarily - stored. - - Returns: - A pytorch DataLoader for the given audio file(s). - """ - if "manifest_filepath" in config: - manifest_filepath = config["manifest_filepath"] - batch_size = config["batch_size"] + if 'augmentor' in data_cfg: + augmentor = process_augmentations( + data_cfg['augmentor'], global_rank=self.global_rank, world_size=self.world_size + ) else: - manifest_filepath = os.path.join(config["temp_dir"], "manifest.json") - batch_size = min(config["batch_size"], len(config["paths2audio_files"])) - - dl_config = { - "manifest_filepath": manifest_filepath, - "sample_rate": self.preprocessor._sample_rate, - "batch_size": batch_size, - "shuffle": False, - "num_workers": config.get( - "num_workers", min(batch_size, os.cpu_count() - 1) - ), - "pin_memory": True, - "channel_selector": config.get("channel_selector", None), - "use_start_end_token": self.cfg.validation_ds.get( - "use_start_end_token", False - ), - } - - if config.get("augmentor"): - dl_config["augmentor"] = config.get("augmentor") - - temporary_datalayer = self._setup_dataloader_from_config( - config=DictConfig(dl_config), for_train=False - ) - return temporary_datalayer - - def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): - """ - TODO(zhehuai): support unpaired data and the mixing of paired and unpaired data. - Sets up the training data loader via a Dict-like object. - - Args: - train_data_config: A config that contains the information regarding construction - of an ASR Training dataset. - - Supported Datasets: - - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` - - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` - - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` - - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` - - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` - """ - if "shuffle" not in train_data_config: - train_data_config["shuffle"] = True - - # preserve config - self._update_dataset_config(dataset_name="train", config=train_data_config) - - self._train_dl = self._setup_dataloader_from_config(config=train_data_config, for_train=True) - - # Need to set this because if using an IterableDataset, the length of the dataloader is the total number - # of samples rather than the number of batches, and this messes up the tqdm progress bar. - # So we set the number of steps manually (to the correct number) to fix this. 
- if ( - self._train_dl is not None - and hasattr(self._train_dl, "dataset") - and isinstance(self._train_dl.dataset, torch.utils.data.IterableDataset) - ): - # We also need to check if limit_train_batches is already set. - # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, - # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). - if self._trainer is not None and isinstance( - self._trainer.limit_train_batches, float - ): - self._trainer.limit_train_batches = int( - self._trainer.limit_train_batches - * ceil( - (len(self._train_dl.dataset) / self.world_size) - / train_data_config["batch_size"] - ) - ) - elif self._trainer is None: - logging.warning( - "Model Trainer was not set before constructing the dataset, incorrect number of " - "training batches will be used. Please set the trainer and rebuild the dataset." - ) - - def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): - """ - Sets up the validation data loader via a Dict-like object. - - Args: - val_data_config: A config that contains the information regarding construction - of an ASR Training dataset. - - Supported Datasets: - - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` - - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` - - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` - - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` - - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` - """ - if "shuffle" not in val_data_config: - val_data_config["shuffle"] = False - - # preserve config - self._update_dataset_config(dataset_name="validation", config=val_data_config) + augmentor = None + for file_path, num_samples in zip(file_names, num_train_samples_per_dataset): + dataset = AudioQuestionAnswerDataset( + manifest_filepath=file_path, + tokenizer=self.tokenizer, + sample_rate=data_cfg.sample_rate, + int_values=data_cfg.get('int_values', False), + augmentor=augmentor, + max_duration=getattr(data_cfg, 'max_duration', None), + min_duration=getattr(data_cfg, 'min_duration', None), + max_utts=getattr(data_cfg, 'max_utts', -1), + trim=getattr(data_cfg, 'trim_silence', False), + channel_selector=getattr(data_cfg, 'channel_selector', None), + max_seq_length=data_cfg.max_seq_length, + min_seq_length=data_cfg.min_seq_length, + add_bos=data_cfg.get('add_bos', False), + add_eos=data_cfg.get('add_eos', True), + add_sep=data_cfg.get('add_sep', False), + sep_id=self.sep_id, + max_num_samples=num_samples[0], + seed=data_cfg.get('seed', 1234), + separate_prompt_and_response_with_newline=data_cfg.get( + 'separate_prompt_and_response_with_newline', True + ), + answer_only_loss=self.cfg.get('answer_only_loss', True), + truncation_field=data_cfg.get('truncation_field', 'context'), + pad_to_max_length=False, + index_mapping_dir=data_cfg.get('index_mapping_dir', None), + prompt_template=data_cfg.get('prompt_template', None), + virtual_tokens=self.virtual_tokens, + tokens_to_generate=data_cfg.get( + 'tokens_to_generate', 0 + ), # used at inference time to allocate tensor positions for tokens that will be generated by inf procedure. 
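+                # the kwargs above combine ASR dataset options (sample_rate, durations, trim,
+                # channel_selector, augmentor) with the GPT SFT text options (prompt_template,
+                # add_bos/add_eos, truncation_field, virtual_tokens, tokens_to_generate)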
+ ) + datasets.append(dataset) - self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, for_train=True) + if is_train and not data_cfg.get('is_tarred', False): + dataset = BlendableDataset( + datasets=datasets, weights=concat_sampling_probabilities, size=num_train_samples_after_blend + ) + return dataset + else: + return datasets - def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + @classmethod + def _modify_config(cls, gpt_cfg, cfg, audio_cfg, add_cfg_to_tree=False): """ - Sets up the test data loader via a Dict-like object. - - Args: - test_data_config: A config that contains the information regarding construction - of an ASR Training dataset. - - Supported Datasets: - - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` - - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` - - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` - - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` - - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + This function modifies the original gpt pre-training config (gpt_cfg) with attributes from the finetuning config (cfg). + The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`. """ - if "shuffle" not in test_data_config: - test_data_config["shuffle"] = False - - # preserve config - self._update_dataset_config(dataset_name="test", config=test_data_config) - - self._test_dl = self._setup_dataloader_from_config(config=test_data_config, for_train=False) - - @property - def input_types(self) -> Optional[Dict[str, NeuralType]]: - if hasattr(self.preprocessor, "_sample_rate"): - input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + OmegaConf.set_struct(gpt_cfg, True) + OmegaConf.resolve(cfg) + with open_dict(gpt_cfg): + gpt_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False) + gpt_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size + gpt_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size + gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False) + gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None) + gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None) + gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None) + gpt_cfg.data = cfg.model.data + gpt_cfg.optim = cfg.model.optim + gpt_cfg.precision = cfg.trainer.precision + gpt_cfg.answer_only_loss = cfg.model.answer_only_loss + gpt_cfg.restore_from_path = cfg.model.restore_from_path + gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint + gpt_cfg.save_nemo_on_validation_end = cfg.model.save_nemo_on_validation_end + gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view + gpt_cfg.hidden_dropout = cfg.model.get('hidden_dropout', 0.0) + gpt_cfg.attention_dropout = cfg.model.get('attention_dropout', 0.0) + gpt_cfg.ffn_dropout = cfg.model.ffn_dropout + gpt_cfg.peft = cfg.model.peft + # for AudioGPTLoRAModel + gpt_cfg.target = f"{cls.__module__}.{cls.__name__}" + gpt_cfg.perception = cfg.model.perception + gpt_cfg.perception.preprocessor = audio_cfg.preprocessor + gpt_cfg.perception.encoder = audio_cfg.encoder + matcher_cfg = gpt_cfg.perception.matcher + matcher_cfg.feat_in = audio_cfg.encoder.d_model + 
gpt_cfg.perception.output_dim = gpt_cfg.hidden_size + # This is needed when modifying a hparam file directly to load `.ckpt` files. + # This is not needed to modify the cfg in `.nemo` files. + if add_cfg_to_tree: + OmegaConf.resolve(gpt_cfg) + gpt_cfg.cfg = gpt_cfg + + return gpt_cfg + + @classmethod + def restore_from_pretrained_models( + cls, cfg: Optional[Union[OmegaConf, str]] = None, trainer: Optional[Trainer] = None, + ): + if not cfg.model.pretrained_audio_model: + raise RuntimeError("PEFT training needs a pretrained audio model present.") + + if not cfg.model.restore_from_path: + raise RuntimeError("PEFT training needs a trained base model present.") + + base_model_save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.model.restore_from_path): + base_model_save_restore_connector.model_extracted_dir = cfg.model.restore_from_path + base_model_cfg = cls.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + return_config=True, + save_restore_connector=base_model_save_restore_connector, + ) + pretrained_audio_model = cfg.model.pretrained_audio_model + if pretrained_audio_model.endswith('.nemo'): + logging.info(f'Loading pretrained audio model from local file: {pretrained_audio_model}') + audio_model = ASRModel.restore_from(pretrained_audio_model, map_location='cpu') else: - input_signal_eltype = AudioSignal() - - return { - "input_signal": NeuralType(("B", "T"), input_signal_eltype, optional=True), - "input_signal_length": NeuralType(tuple("B"), LengthsType(), optional=True), - "processed_signal": NeuralType( - ("B", "D", "T"), SpectrogramType(), optional=True - ), - "processed_signal_length": NeuralType( - tuple("B"), LengthsType(), optional=True - ), - } - - @property - def output_types(self) -> Optional[Dict[str, NeuralType]]: - return { - "outputs": NeuralType(("B", "D", "T"), AcousticEncodedRepresentation()), - "encoded_lengths": NeuralType(tuple("B"), LengthsType()), - } + logging.info(f'Loading pretrained audio model from NGC: {pretrained_audio_model}') + audio_model = ASRModel.from_pretrained(pretrained_audio_model, map_location='cpu') + + model_cfg = cls._modify_config(base_model_cfg, cfg, audio_model.cfg, add_cfg_to_tree=False) + resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path + save_restore_connector = PEFTSaveRestoreConnector( + peft_model_nemo_path=cfg.model.peft.restore_from_path, peft_model_ckpt_path=resume_from_checkpoint + ) + if os.path.isdir(cfg.model.restore_from_path): + save_restore_connector.model_extracted_dir = cfg.model.restore_from_path + + # load llm + model = cls.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + override_config_path=model_cfg, + save_restore_connector=save_restore_connector, + strict=False, + ) + # load am + model.perception.encoder.load_state_dict(audio_model.encoder.state_dict(), strict=True) + logging.info(f'Loaded pretrained audio model from {pretrained_audio_model}') + return model diff --git a/nemo/collections/multimodal/modules/__init__.py b/nemo/collections/multimodal/modules/__init__.py index 3bb6a3cbacfa..e23f18c12d23 100644 --- a/nemo/collections/multimodal/modules/__init__.py +++ b/nemo/collections/multimodal/modules/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from nemo.collections.multimodal.modules.speechllm_perception import * \ No newline at end of file +from nemo.collections.multimodal.modules.speechllm_perception import * diff --git a/nemo/collections/multimodal/modules/speechllm_perception.py b/nemo/collections/multimodal/modules/speechllm_perception.py index 4270e93316fb..3745c5971c73 100644 --- a/nemo/collections/multimodal/modules/speechllm_perception.py +++ b/nemo/collections/multimodal/modules/speechllm_perception.py @@ -13,22 +13,17 @@ # limitations under the License. from collections import OrderedDict +from typing import Any, Dict, Optional import torch import torch.distributed import torch.nn as nn +from omegaconf.dictconfig import DictConfig -from typing import Optional, Dict, Any from nemo.core.classes.common import typecheck from nemo.core.classes.exportable import Exportable from nemo.core.classes.module import NeuralModule -from nemo.core.neural_types import ( - AcousticEncodedRepresentation, - AudioSignal, - LengthsType, - NeuralType, - SpectrogramType, -) +from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType __all__ = ["AudioPerceptionModel"] @@ -36,9 +31,7 @@ class AudioPerceptionModel(NeuralModule, Exportable): """Audio perception model with basic matcher (some fc layers).""" - def input_example( - self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200 - ): + def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200): batch_size = torch.randint(low=1, high=max_batch, size=[1]).item() max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item() signals = torch.rand(size=[batch_size, max_length]) * 2 - 1 @@ -51,9 +44,7 @@ def input_types(self): """Returns definitions of module input ports.""" return OrderedDict( { - "input_signal": NeuralType( - ("B", "T"), AudioSignal(freq=self.preprocessor._sample_rate) - ), + "input_signal": NeuralType(("B", "T"), AudioSignal(freq=self.preprocessor._sample_rate)), "input_signal_length": NeuralType( tuple("B"), LengthsType() ), # Please note that length should be in samples not seconds. 
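A minimal usage sketch of the module defined in this file (illustration only; perception_cfg is an assumed DictConfig carrying the preprocessor, encoder, matcher, spec_augment and output_dim entries from the example YAML config):

    import torch
    from nemo.collections.multimodal.modules.speechllm_perception import AudioPerceptionModel

    perception = AudioPerceptionModel(cfg=perception_cfg)
    perception.eval()  # skip SpecAugment for this illustrative forward pass
    wavs = torch.randn(2, 16000)             # [B, T] raw audio at the preprocessor sample rate (16 kHz assumed)
    wav_lens = torch.tensor([16000, 12000])  # lengths in samples, not seconds
    encoded, encoded_len = perception(input_signal=wavs, input_signal_length=wav_lens)
    # encoded: [B, T', output_dim] frame embeddings that ModularizedAudioGPTModel concatenates
    # with the LLM token embeddings; encoded_len: [B] valid frame counts per sample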
@@ -72,37 +63,20 @@ def output_types(self): } ) - def __init__( - self, - preprocessor: Dict[str, Any], - encoder: Dict[str, Any], - matcher: Dict[str, Any], - d_model: int, - spec_augment: Optional[Dict[str, Any]] = None, - freeze_encoder: bool = False, - ): + def __init__(self, cfg: DictConfig): super().__init__() # Initialize components - self.preprocessor = preprocessor - self.encoder = encoder - self.spec_augmentation = spec_augment - self.matcher = matcher - self.proj = nn.Linear(matcher.d_model, d_model) - if freeze_encoder: - for params in self.encoder.parameters(): - params.requires_grad = False + self.preprocessor = self.from_config_dict(cfg.preprocessor) + self.encoder = self.from_config_dict(cfg.encoder) + self.spec_augmentation = self.from_config_dict(cfg.spec_augment) + self.matcher = self.from_config_dict(cfg.matcher) + self.proj = nn.Linear(cfg.matcher.d_model, cfg.output_dim) def maybe_preprocess_audio( - self, - input_signal=None, - input_signal_length=None, - processed_signal=None, - processed_signal_length=None, + self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None, ): has_input_signal = input_signal is not None and input_signal_length is not None - has_processed_signal = ( - processed_signal is not None and processed_signal_length is not None - ) + has_processed_signal = processed_signal is not None and processed_signal_length is not None if (has_input_signal ^ has_processed_signal) is False: raise ValueError( f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " @@ -111,18 +85,13 @@ def maybe_preprocess_audio( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, - length=input_signal_length, + input_signal=input_signal, length=input_signal_length, ) return processed_signal, processed_signal_length @typecheck() def forward( - self, - input_signal=None, - input_signal_length=None, - processed_signal=None, - processed_signal_length=None, + self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None, ): processed_signal, processed_signal_length = self.maybe_preprocess_audio( input_signal, input_signal_length, processed_signal, processed_signal_length @@ -130,15 +99,11 @@ def forward( # Spec augment is not applied during evaluation/testing if self.spec_augmentation is not None and self.training: - processed_signal = self.spec_augmentation( - input_spec=processed_signal, length=processed_signal_length - ) + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) - encoded, encoded_len = self.encoder( - audio_signal=processed_signal, length=processed_signal_length - ) + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) encoded, encoded_len = self.matcher(audio_signal=encoded, length=encoded_len) # b, t, c - encoded = self.proj(encoded.transpose(1,2)) + encoded = self.proj(encoded.transpose(1, 2)) return encoded, encoded_len diff --git a/nemo/collections/multimodal/parts/__init__.py b/nemo/collections/multimodal/parts/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/multimodal/parts/utils/data_utils.py b/nemo/collections/multimodal/parts/utils/data_utils.py new file mode 100644 index 000000000000..306019de4384 --- /dev/null +++ b/nemo/collections/multimodal/parts/utils/data_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + + +def maybe_cast_to_list(x): + if isinstance(x, np.ndarray): + return [item.tolist() for item in x] + return x + + +def ceil_to_nearest(n, m): + return (n + m - 1) // m * m diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py index 73579114234d..fa6f3553a757 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py @@ -106,7 +106,7 @@ def load_state_dict(self, state_dict, strict: bool = True): assert set(state_dict.keys()) == self.adapter_keys super().load_state_dict(state_dict, strict=False) else: - super().load_state_dict(state_dict, strict=True) + super().load_state_dict(state_dict, strict=strict) def setup_optimizer_param_groups(self): """ diff --git a/tests/collections/multimodal/test_speechllm_models.py b/tests/collections/multimodal/test_speechllm_models.py index f19bbbc89199..b9465de85e22 100644 --- a/tests/collections/multimodal/test_speechllm_models.py +++ b/tests/collections/multimodal/test_speechllm_models.py @@ -27,16 +27,20 @@ from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from nemo.collections.multimodal.models.speechllm_models import ( - ModularizedSpeechGPTModel, -) +from nemo.collections.multimodal.models import speechllm_models from nemo.collections.asr.models.hybrid_asr_tts_models import ASRWithTTSModel -from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import ( - MegatronGPTModel, +from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import ( + GPTModel, ) +class ModularizedAudioGPTModel(speechllm_models.ModularizedAudioGPTModel): + # disable logging to avoid MisconfigurationException + def log(self, *args, **kwargs): + pass + def setup_module(): + pl.seed_everything(1) # init model parallel needed for LLM loss init_method = 'tcp://' master_ip = 'localhost' @@ -56,11 +60,13 @@ def llm_model_config(): "../../../examples/multimodel/conf/speechllm/modularized_speech_gpt_config.yaml", ) ) - # TODO(zhehuai): update train_ds and validation_ds - # wget -nc --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/nemo/megatron_gpt_345m/versions/1/files/megatron_gpt_345m.nemo -O /home/TestData/nlp/megatron_gpt/megatron_gpt_345m.nemo - config.model.language_model_path = ( - "/home/TestData/nlp/megatron_gpt/megatron_gpt_345m.nemo" + # TODO(zhehuai): move the following to Test /home/TestData + config.model.restore_from_path = ( + "/root/home/works/TestData/pretrained_models/megatron_gpt/gpt_pretrain_220m_len_4096_pos_alibi_step_595508_gbs256.nemo" ) + config.model.micro_batch_size = 64 + config.model.data.validation_ds.file_names = '/root/home/works/TestData/datasets/LibriSpeech/dev_clean_cleaned.json' + 
config.model.data.train_ds.file_names = '/root/home/works/TestData/datasets/LibriSpeech/dev_clean_cleaned.json' return config @@ -77,10 +83,11 @@ def trainer_config(): config_trainer.devices = 1 config_trainer.num_nodes = 1 config_trainer.max_epochs = 4 + config_trainer.max_steps = 1 config_trainer.val_check_interval = 1.0 # for PyTorch Native AMP set precision=16 - config_trainer.precision = 16 if torch.cuda.is_available() else 32 + config_trainer.precision = 32 # setup cluster environment parameters" # use torch elastic cluster environment so `create_process_externally` is True @@ -93,7 +100,7 @@ def trainer_config(): strategy = NLPDDPStrategy( ) plugins = [TorchElasticEnvironment()] - trainer = pl.Trainer(logger=None, plugins=plugins, strategy=strategy, **config_trainer) + trainer = pl.Trainer(logger=False, plugins=plugins, strategy=strategy, **config_trainer) return trainer, config_trainer @@ -117,86 +124,93 @@ def perception_model_config(): "preprocessor": DictConfig(preprocessor), "encoder": DictConfig(encoder), "matcher": DictConfig(encoder), - "d_model": 1024, + "output_dim": 1024, } ) return model_config -class TestModularizedSpeechGPTModel: +@pytest.fixture +def test_batch(): + signal_len = torch.from_numpy(np.array([64000, 64000])) + transcript = torch.arange(10).reshape(2, 5).int() + tokens = transcript[:,:-1] + labels = transcript[:,1:] + transcript_length = torch.Tensor([3, 2]).int() + # assuming context_lengths = [1, 1] + loss_mask = torch.Tensor([[0, 1, 1, 0], [0, 1, 0, 0]]) + batch = { + 'audio_signal_length':signal_len, + 'tokens':tokens, + 'tokens_length':transcript_length, + 'labels':labels, + 'loss_mask': loss_mask + } + batch['audio_signal'] = torch.randn([2, 64000]) + return batch + +class TestModularizedAudioGPTModel: @pytest.mark.unit def test_init_and_train( self, llm_model_config, perception_model_config, trainer_config ): + llm_model_config.model.pretrained_audio_model = "stt_en_fastconformer_transducer_large" llm_model_config.model.perception = perception_model_config trainer, llm_model_config.trainer = trainer_config - model = ModularizedSpeechGPTModel(cfg=llm_model_config.model, trainer=trainer) - - assert isinstance(model.frozen_model, MegatronGPTModel) + model = ModularizedAudioGPTModel.restore_from_pretrained_models(llm_model_config, trainer=trainer) + + assert isinstance(model.model, GPTModel) with tempfile.TemporaryDirectory() as tmpdir: save_path = str(Path(tmpdir) / "model.nemo") model.train() model.save_to(save_path) + @pytest.mark.unit def test_prepare_llm_input( - self, llm_model_config, perception_model_config, trainer_config + self, llm_model_config, perception_model_config, trainer_config, test_batch ): + llm_model_config.model.pretrained_audio_model = "stt_en_fastconformer_transducer_large" llm_model_config.model.perception = perception_model_config trainer, llm_model_config.trainer = trainer_config - model = ModularizedSpeechGPTModel(cfg=llm_model_config.model, trainer=trainer) + model = ModularizedAudioGPTModel.restore_from_pretrained_models(llm_model_config, trainer=trainer) model.cuda() model.train() - pl.seed_everything(1) - signal = torch.randn(2, 64000).cuda() - signal_len = torch.from_numpy(np.array([64000, 64000])).cuda() - transcript = torch.arange(8).reshape(2, 4).int().cuda() - transcript_length = torch.from_numpy(np.array([3, 2])).cuda() - batch = signal, signal_len, transcript, transcript_length + batch = {key: val.cuda(non_blocking=True) for key, val in test_batch.items()} encoder_input, attention_mask, labels, loss_mask, 
encoder_length = model.prepare_llm_input(batch) - assert encoder_input.shape == (40, 2, 1024) - assert np.allclose(encoder_input.sum().cpu().detach().numpy(), -788.436) - assert attention_mask.shape == (2, 1, 40, 40) - assert labels.shape == (2, 40) + assert encoder_input.shape == (17, 2, 768) + assert np.allclose(encoder_input.sum().cpu().detach().numpy(), 15.783691) + assert attention_mask.shape == (2, 1, 17, 17) + assert labels.shape == (2, 17) assert np.allclose(loss_mask.sum(axis=1).cpu().numpy(), [2, 1]) - assert np.allclose(encoder_length.cpu().numpy(), (39, 38)) + assert np.allclose(encoder_length.cpu().numpy(), (16, 15)) @pytest.mark.unit def test_training_step( - self, llm_model_config, perception_model_config, trainer_config + self, llm_model_config, perception_model_config, trainer_config, test_batch ): + llm_model_config.model.pretrained_audio_model = "stt_en_fastconformer_transducer_large" llm_model_config.model.perception = perception_model_config trainer, llm_model_config.trainer = trainer_config - model = ModularizedSpeechGPTModel(cfg=llm_model_config.model, trainer=trainer) + model = ModularizedAudioGPTModel.restore_from_pretrained_models(llm_model_config, trainer=trainer) model.cuda() + model.on_train_start() + model.setup() model.train() - pl.seed_everything(1) - signal = torch.randn(2, 64000) - signal_len = torch.from_numpy(np.array([64000, 64000])) - transcript = torch.arange(8).reshape(2, 4).int() - transcript_length = torch.from_numpy(np.array([3, 2])) - batch = signal, signal_len, transcript, transcript_length - loss_mean = model.training_step(batch, None) - assert np.allclose(loss_mean.cpu().detach().numpy(), 7.757044) + loss_mean = model.training_step(iter([test_batch]), None) + assert np.allclose(loss_mean.cpu().detach().numpy(), 6.014655) @pytest.mark.unit def test_validation_step( - self, llm_model_config, perception_model_config, trainer_config + self, llm_model_config, perception_model_config, trainer_config, test_batch ): + llm_model_config.model.pretrained_audio_model = "stt_en_fastconformer_transducer_large" llm_model_config.model.perception = perception_model_config trainer, llm_model_config.trainer = trainer_config - model = ModularizedSpeechGPTModel(cfg=llm_model_config.model, trainer=trainer) + model = ModularizedAudioGPTModel.restore_from_pretrained_models(llm_model_config, trainer=trainer) model.cuda() model.train() - pl.seed_everything(1) - signal = torch.randn(2, 64000) - signal_len = torch.from_numpy(np.array([64000, 64000])) - transcript = torch.arange(8).reshape(2, 4).int() - transcript_length = torch.from_numpy(np.array([3, 2])) - batch = signal, signal_len, transcript, transcript_length - loss_mean = model.validation_step(batch, None) - assert np.allclose(loss_mean['loss'].cpu().detach().numpy(), 7.7556906) - - # TODO(zhehuai): test ckpt restore + loss_mean = model.validation_step(iter([test_batch]), None) + assert np.allclose(loss_mean['loss'].cpu().detach().numpy(), 5.9237595) diff --git a/workspace/run_sft.sh b/workspace/run_sft.sh new file mode 100755 index 000000000000..f5179d30cb81 --- /dev/null +++ b/workspace/run_sft.sh @@ -0,0 +1,15 @@ +NEMO_DIR=/workspace/nemo/works/mod_speech_llm/NeMo +export PYTHONPATH=$NEMO_DIR:$PYTHONPATH + +MEGATRON_CKPT=/media/data3/pretrained_models/megatron_gpt/gpt_pretrain_220m_len_4096_pos_alibi_step_595508_gbs256.nemo +ASR_MODEL="stt_en_fastconformer_transducer_large" + +TRAIN_MANIFESTS=/media/data/datasets/LibriSpeech/train_clean_100_cleaned.json 
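+# NeMo-style JSON manifests for LibriSpeech; dev_clean below is used for validation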
+VAL_MANIFESTS=/media/data/datasets/LibriSpeech/dev_clean.json + +python -m pdb -c continue run_sft_audio_lm.py --config-path="../examples/multimodel/conf/speechllm/" --config-name "modularized_speech_gpt_config" \ + model.pretrained_audio_model=$ASR_MODEL \ + model.restore_from_path=$MEGATRON_CKPT \ + model.data.train_ds.file_names=$TRAIN_MANIFESTS \ + model.data.validation_ds.file_names=$VAL_MANIFESTS + diff --git a/workspace/run_sft_audio_lm.py b/workspace/run_sft_audio_lm.py new file mode 100644 index 000000000000..3e84416666c4 --- /dev/null +++ b/workspace/run_sft_audio_lm.py @@ -0,0 +1,129 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import tempfile + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf, open_dict +from pytorch_lightning import Trainer +from pytorch_lightning.plugins.environments import TorchElasticEnvironment +from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector +from torch.utils.data import DataLoader, Dataset + +from nemo.collections.asr.models import ASRModel +from nemo.collections.multimodal.models.speechllm_models import ModularizedAudioGPTModel +from nemo.collections.nlp.models.language_modeling.megatron_gpt_peft_models import ( + MegatronGPTAdapterModel, + MegatronGPTAdapterPTuningModel, + MegatronGPTIA3Model, + MegatronGPTLoRAModel, + MegatronGPTPTuningModel, +) +from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel +from nemo.collections.nlp.parts.nlp_overrides import ( + GradScaler, + MegatronHalfPrecisionPlugin, + NLPDDPStrategy, + NLPSaveRestoreConnector, + PEFTSaveRestoreConnector, + PipelineMixedPrecisionPlugin, +) +from nemo.core.config import hydra_runner +from nemo.utils import AppState, logging +from nemo.utils.exp_manager import exp_manager +from nemo.utils.model_utils import inject_model_parallel_rank + +mp.set_start_method("spawn", force=True) + +""" +This is the script to train an Adapter infused GPT Model for audio question answering. +A base GPT Model is required as a starting point. This script will then insert +Adapters into each Transformer layer and will train/update only these adapters +during training. The base GPT Model weights will remain frozen. + +During training this script will only save the newly trained Adapter weights +in checkpoints. At the end of training a .nemo file of Adapter weights will +be saved. + +Usage: + Assuming the base model is a 125m GPT Model, with TP=1, PP=1: + a. 
run a training run for a base gpt nemo file: + python megatron_gpt_adapter_tuning.py \ + model.data.train_ds=[PATH TO TRAINING JSONL FILE], \ + model.data.validation_ds=[PATH TO VALIDATION JSONL FILE]",\ + model.pretrained_audio_model="PATH TO ASR MODEL (.nemo FILE or NGC MODEL NAME)" \ + model.restore_from_path="PATH TO BASE GPT MODEL .nemo FILE" \ + name="NAME OF TRAINING RUN" \ + exp_manager.exp_dir="DIR TO SAVE CHECKPOINTS and .nemo FILE" \ + trainer.max_epochs=2 +""" + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_peft_tuning_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False) + with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam' + + plugins = [] + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, # we don't use DDP for async grad allreduce + gradient_as_bucket_view=cfg.model.gradient_as_bucket_view, + find_unused_parameters=False, + ) + if cfg.trainer.precision in [16, 'bf16']: + scaler = None + if cfg.trainer.precision == 16: + scaler = GradScaler( + init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + growth_interval=cfg.model.get('native_amp_growth_interval', 1000), + hysteresis=cfg.model.get('hysteresis', 2), + enabled=False + if cfg.model.pipeline_model_parallel_size > 1 + else True, # turn off the grad scale for pipeline parallel LM model + ) + if megatron_amp_o2 and not with_distributed_adam: + plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler)) + else: + plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler)) + + if cfg.get('cluster_type', None) == 'BCP': + plugins.append(TorchElasticEnvironment()) + + trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) + exp_manager(trainer, cfg.exp_manager) + # update resume from checkpoint found by exp_manager + if cfg.model.resume_from_checkpoint is not None: + resume_from_checkpoint = cfg.model.resume_from_checkpoint + else: + resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path + logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}') + + trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint) + + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + model = ModularizedAudioGPTModel.restore_from_pretrained_models(cfg, trainer=trainer) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/workspace/scripts/convert_hf_llama_to_nemo.sh b/workspace/scripts/convert_hf_llama_to_nemo.sh new file mode 100755 index 000000000000..5b8629529804 --- /dev/null +++ b/workspace/scripts/convert_hf_llama_to_nemo.sh @@ -0,0 +1,16 @@ + + +NEMO_ROOT=/home/heh/codes/nemo-slm +export PYTHONPATH=${NEMO_ROOT}:${PYTHONPATH} + +TP=1 +PP=1 + +# INPUT_DIR=/media/data3/speech_llm/llama_hf/llama_hf/llama-13b-hf +# OUTPUT_DIR=/media/data3/speech_lm/llama/llama-13b-nemo + +MODEL_SIZE="llama-2-7b-chat" +INPUT_DIR=/media/data3/pretrained_models/llama2_hf/$MODEL_SIZE +OUTPUT_DIR=/media/data3/pretrained_models/llama2_nemo/$MODEL_SIZE/tp${TP}_pp${PP}/$MODEL_SIZE.nemo + +WORLD_SIZE=$TP python ../tools/convert_hf_llama_to_nemo.py --input_dir $INPUT_DIR --output_file $OUTPUT_DIR --local_rank 
0 --tensor_model_parallel_size $TP --pipeline_model_parallel_size $PP diff --git a/workspace/scripts/convert_llama_to_hf.sh b/workspace/scripts/convert_llama_to_hf.sh new file mode 100755 index 000000000000..90e91b1d5b95 --- /dev/null +++ b/workspace/scripts/convert_llama_to_hf.sh @@ -0,0 +1,11 @@ + +HF_ROOT=/home/heh/github/transformers +SCRIPT=../tools/convert_llama2_to_hf.py + +MODEL_SIZE="llama-2-7b-chat" +LLAMA_DIR=/media/data3/pretrained_models/llama2_raw +OUTPUT_DIR=/media/data3/pretrained_models/llama2_hf/$MODEL_SIZE + + +python $SCRIPT --input_dir $LLAMA_DIR --model_size $MODEL_SIZE --output_dir $OUTPUT_DIR + diff --git a/workspace/tools/convert_hf_llama_to_nemo.py b/workspace/tools/convert_hf_llama_to_nemo.py new file mode 100644 index 000000000000..301fcc6e243d --- /dev/null +++ b/workspace/tools/convert_hf_llama_to_nemo.py @@ -0,0 +1,263 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +https://gitlab-master.nvidia.com/hongbinl/NeMo/-/blob/support_llama/examples/nlp/language_modeling/convert_hf_llama_to_nemo.py +""" + +import os +from argparse import ArgumentParser +from collections import OrderedDict + +import torch +from omegaconf import OmegaConf +from pytorch_lightning.core.saving import _load_state as ptl_load_state +from pytorch_lightning.trainer.trainer import Trainer +from transformers import LlamaForCausalLM + +from nemo.collections.nlp.models.language_modeling.megatron_llama_model import MegatronLLAMAModel +from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.utils import AppState, logging +from nemo.utils.distributed import initialize_distributed + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--input_dir", + type=str, + default=None, + required=True, + help="Path to Megatron-LM checkpoints saved during training. 
Ex: /raid/Megatron_LM/checkpoints", + ) + parser.add_argument("--output_file", type=str, default=None, required=False, help="Path to output .nemo file.") + + parser.add_argument("--gpus_per_node", type=int, required=False, default=1) + + parser.add_argument("--tensor_model_parallel_size", type=int, required=False, default=1) + parser.add_argument("--pipeline_model_parallel_size", type=int, required=False, default=1) + + parser.add_argument("--local_rank", type=int, required=False, default=os.getenv('LOCAL_RANK', 0)) + + parser.add_argument("--model_type", type=str, required=False, default="gpt", choices=["gpt", "t5", "bert"]) + + args = parser.parse_args() + return args + + +def load_model(cls, checkpoint, strict, **kwargs): + print(checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]) + try: + if 'cfg' in kwargs: + model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs) + else: + model = ptl_load_state( + cls, checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg, **kwargs + ) + # register the artifacts + cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg + if cfg.tokenizer.model is not None: + model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model) + if cfg.tokenizer.vocab_file is not None: + model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file) + if cfg.tokenizer.merge_file is not None: + model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file) + finally: + cls._set_model_restore_state(is_being_restored=False) + return model + + +def load_config(llama_config, args): + nemo_config = {} + nemo_config['cfg'] = {} + nemo_config['cfg']['encoder_seq_length'] = llama_config.get( + 'max_sequence_length', llama_config['max_position_embeddings'] + ) + nemo_config['cfg']['num_layers'] = int(llama_config['num_hidden_layers']) + nemo_config['cfg']['hidden_size'] = llama_config['hidden_size'] + nemo_config['cfg']['ffn_hidden_size'] = llama_config['intermediate_size'] + nemo_config['cfg']['num_attention_heads'] = llama_config['num_attention_heads'] + nemo_config['cfg']['max_position_embeddings'] = llama_config['max_position_embeddings'] + nemo_config['cfg']['init_method_std'] = llama_config['initializer_range'] + nemo_config['cfg']['normalization'] = 'rmsnorm' + nemo_config['cfg']['layernorm_epsilon'] = llama_config['rms_norm_eps'] + nemo_config['cfg']['pre_process'] = True + nemo_config['cfg']['post_process'] = True + nemo_config['cfg']['bias'] = False + nemo_config['cfg']['hidden_dropout'] = 0.0 + nemo_config['cfg']['attention_dropout'] = 0.0 + nemo_config['cfg']['ffn_dropout'] = 0.0 + nemo_config['cfg']['bias_dropout_add_fusion'] = False + nemo_config['cfg']['bias_activation_fusion'] = False + nemo_config['cfg']['use_cpu_initialization'] = True + nemo_config['cfg']['share_embeddings_and_output_weights'] = False + nemo_config['cfg']['make_vocab_size_divisible_by'] = 128 + nemo_config['cfg']['activation'] = 'swiglu' + nemo_config['cfg']['transformer_block_type'] = 'pre_ln' + nemo_config['cfg']['position_embedding_type'] = 'rope' + nemo_config['cfg']['precision'] = 32 + nemo_config['cfg']['optim'] = {'name': 'fused_adam'} + nemo_config['cfg']['tokenizer'] = {} + nemo_config['cfg']['tokenizer']['library'] = 'sentencepiece' + nemo_config['cfg']['tokenizer']['type'] = 'null' + nemo_config['cfg']['tokenizer']['model'] = f'{args.input_dir}/tokenizer.model' + nemo_config['cfg']['tokenizer']['vocab_file'] = 'null' + nemo_config['cfg']['tokenizer']['merge_file'] = 'null' + nemo_config['cfg']['tokenizer']['tokenizer_model'] = 
'null' + nemo_config['cfg']['tokenizer']['sentencepiece_legacy'] = False + nemo_config['cfg']['micro_batch_size'] = 1 + nemo_config['cfg']['global_batch_size'] = 1 + + nemo_config['cfg']['use_scaled_init_method'] = True + nemo_config['cfg']['normalize_attention_scores'] = True + nemo_config['cfg']['grad_allreduce_chunk_size_mb'] = 125 + nemo_config['cfg']['persist_layer_norm'] = True + nemo_config['cfg']['masked_softmax_fusion'] = True + print(nemo_config) + return nemo_config + + +def convert(local_rank, rank, world_size, args): + + app_state = AppState() + initialize_model_parallel_for_nemo( + world_size=world_size, + global_rank=rank, + local_rank=local_rank, + tensor_model_parallel_size=args.tensor_model_parallel_size, + pipeline_model_parallel_size=args.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=0, + micro_batch_size=None, + global_batch_size=None, + seed=1234, + apex_transformer_log_level=30, + ) + # hard set the data parallel rank to 0, otherwiaze it is default to None + app_state.data_parallel_rank = 0 + + # tensor_model_parallel_size = args.tensor_model_parallel_size + num_nodes = world_size // args.gpus_per_node + assert world_size % args.gpus_per_node == 0, "world_size must be divisible by gpus_per_node" + + trainer = Trainer(devices=args.gpus_per_node, accelerator='cpu', num_nodes=num_nodes) + + logging.info(f"loading checkpoint {args.input_dir}") + model = LlamaForCausalLM.from_pretrained(args.input_dir) + hf_config = vars(model.config) + nemo_config = load_config(hf_config, args) + print(f"hf_config: {hf_config}") + print(f"nemo_config: {nemo_config}") + + hidden_size = hf_config["hidden_size"] + head_num = hf_config["num_attention_heads"] + head_size = hidden_size // head_num + num_layers = hf_config["num_hidden_layers"] + + checkpoint = None + checkpoint = OrderedDict() + checkpoint['state_dict'] = OrderedDict() + + embed_weight = model.state_dict()[f'model.embed_tokens.weight'] + embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight' + checkpoint['state_dict'][embed_weights_base_name] = embed_weight + + rotary_embed_weight = model.state_dict()[f'model.layers.0.self_attn.rotary_emb.inv_freq'] + rotary_embed_weight_base_name = f'model.language_model.rotary_pos_emb.inv_freq' + checkpoint['state_dict'][rotary_embed_weight_base_name] = rotary_embed_weight + + for l in range(int(num_layers)): + print(f"converting layer {l}") + # first merge QKV into a single weight + # concat direct to FT shape: [hidden_size, 3, head_num, head_size] + # copied from huggingface_gptj_ckpt_convert.py + new_tensor_shape = (head_num, head_size) + model.state_dict()[ + f'model.layers.{l}.self_attn.q_proj.weight' + ].size()[1:] + q = model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight'].view(*new_tensor_shape) + k = model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight'].view(*new_tensor_shape) + v = model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight'].view(*new_tensor_shape) + qkv_weights = torch.cat((q, k, v), axis=1) + qkv_weights = qkv_weights.reshape([3 * hidden_size, hidden_size]) + qkv_weights_base_name = f'model.language_model.encoder.layers.{l}.self_attention.query_key_value.weight' + checkpoint['state_dict'][qkv_weights_base_name] = qkv_weights + + # attention dense + o_weight = model.state_dict()[f'model.layers.{l}.self_attn.o_proj.weight'] + o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight' + 
checkpoint['state_dict'][o_weight_base_name] = o_weight + + # MLP + mlp_down_weight = model.state_dict()[f'model.layers.{l}.mlp.gate_proj.weight'] + mlp_down_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h.weight' + checkpoint['state_dict'][mlp_down_base_name] = mlp_down_weight + + mlp_gate_weight = model.state_dict()[f'model.layers.{l}.mlp.up_proj.weight'] + mlp_gate_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h_2.weight' + checkpoint['state_dict'][mlp_gate_base_name] = mlp_gate_weight + + mlp_up_weight = model.state_dict()[f'model.layers.{l}.mlp.down_proj.weight'] + mlp_up_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_4h_to_h.weight' + checkpoint['state_dict'][mlp_up_base_name] = mlp_up_weight + + # LayerNorm + input_ln_weight = model.state_dict()[f'model.layers.{l}.input_layernorm.weight'] + input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight' + checkpoint['state_dict'][input_ln_base_name] = input_ln_weight + + post_attn_ln_weight = model.state_dict()[f'model.layers.{l}.post_attention_layernorm.weight'] + post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight' + checkpoint['state_dict'][post_attn_ln_base_name] = post_attn_ln_weight + + print(f"done layer {l}") + + final_ln_weight = model.state_dict()[f'model.norm.weight'] + final_ln_base_name = f'model.language_model.encoder.final_layernorm.weight' + checkpoint['state_dict'][final_ln_base_name] = final_ln_weight + + output_layer_weight = model.state_dict()[f'lm_head.weight'] + output_layer_base_name = f'model.language_model.output_layer.weight' + checkpoint['state_dict'][output_layer_base_name] = output_layer_weight + + checkpoint[MegatronLLAMAModel.CHECKPOINT_HYPER_PARAMS_KEY] = OmegaConf.create(nemo_config) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + model = load_model(MegatronLLAMAModel, checkpoint, strict=False, trainer=trainer) + + # verify tensor parallel rank id and pipeline parallel rank id matches + assert app_state.data_parallel_size == 1 + model._save_restore_connector = NLPSaveRestoreConnector() + model.save_to(args.output_file) + logging.info(f'NeMo model saved to: {args.output_file}') + + +if __name__ == '__main__': + args = get_args() + if args.local_rank == -1: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + rank = args.local_rank + local_rank = rank + world_size = 1 + else: + local_rank, rank, world_size = initialize_distributed(args) + + # make sure the world size is divisible by tensor model parallel_size + assert world_size % args.tensor_model_parallel_size == 0 + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + convert(local_rank, rank, world_size, args) + if torch.distributed.is_initialized(): + torch.distributed.barrier() diff --git a/workspace/tools/convert_llama2_to_hf.py b/workspace/tools/convert_llama2_to_hf.py new file mode 100644 index 000000000000..4cab37a66628 --- /dev/null +++ b/workspace/tools/convert_llama2_to_hf.py @@ -0,0 +1,310 @@ +# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from `/src/transformers/models/llama/convert_llama_weights_to_hf.py` + +import argparse +import gc +import json +import os +import shutil +import warnings + +import torch + +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer + + +try: + from transformers import LlamaTokenizerFast +except ImportError as e: + warnings.warn(e) + warnings.warn( + "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" + ) + LlamaTokenizerFast = None + +""" +Sample usage: + +``` +python src/transformers/models/llama/convert_llama2_to_hf.py \ + --input_dir /path/to/downloaded/llama/weights --model_size llama-2-7b --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import LlamaForCausalLM, LlamaTokenizer + +model = LlamaForCausalLM.from_pretrained("/output/path") +tokenizer = LlamaTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). +""" + +INTERMEDIATE_SIZE_MAP = { + "llama-2-7b": 11008, + "llama-2-13b": 13824, + "llama-2-70b": 28672, +} +NUM_SHARDS = { + "llama-2-7b": 1, + "llama-2-7b-chat": 1, + "llama-2-13b": 2, + "llama-2-13b-chat": 2, + "llama-2-70b": 8, + "llama-2-70b-chat": 8, +} + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model(model_path, input_base_path, model_size, safe_serialization=True): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + params = read_json(os.path.join(input_base_path, "params.json")) + num_shards = NUM_SHARDS[model_size] + n_layers = params["n_layers"] + n_heads = params["n_heads"] + n_heads_per_shard = n_heads // num_shards + dim = params["dim"] + dims_per_head = dim // n_heads + base = 10000.0 + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + + if "n_kv_heads" in params: + num_key_value_heads = params["n_kv_heads"] # for GQA / MQA + num_local_key_value_heads = n_heads_per_shard // num_key_value_heads + key_value_dim = dim // num_key_value_heads + else: # compatibility with other checkpoints + num_key_value_heads = n_heads + num_local_key_value_heads = n_heads_per_shard + key_value_dim = dim + + # permute for sliced rotary + def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + # Load weights + if model_size == "llama-2-7b": + # Not sharded + # (The sharded implementation 
would also work, but this is simpler.)
+        loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
+    else:
+        # Sharded
+        loaded = [
+            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
+            for i in range(num_shards)
+        ]
+    param_count = 0
+    index_dict = {"weight_map": {}}
+    for layer_i in range(n_layers):
+        filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
+        if model_size == "llama-2-7b":
+            # Unsharded
+            state_dict = {
+                f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
+                    loaded[f"layers.{layer_i}.attention.wq.weight"]
+                ),
+                f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
+                    loaded[f"layers.{layer_i}.attention.wk.weight"]
+                ),
+                f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
+                f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
+                f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
+                f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
+                f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
+                f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
+                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
+            }
+        else:
+            # Sharded
+            # Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
+            # the same storage object; saving attention_norm and ffn_norm would therefore save the other weights too,
+            # which is redundant since those weights are stitched together from multiple shards. To avoid that, they are cloned.
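+            # How the shards are stitched back together below: wq/wk are rebuilt by concatenating the
+            # per-shard [heads-per-shard, head_size, dim] views along the head axis and re-applying
+            # `permute` for the sliced-rotary layout; wv and the other column-parallel weights (w1/w3)
+            # are concatenated along dim 0, while the row-parallel weights (wo/w2) are concatenated along dim 1.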
+ + state_dict = { + f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( + torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) + for i in range(num_shards) + ], + dim=0, + ).reshape(dim, dim) + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( + torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( + num_local_key_value_heads, dims_per_head, dim + ) + for i in range(num_shards) + ], + dim=0, + ).reshape(key_value_dim, dim), + num_key_value_heads, + key_value_dim, + dim, + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( + num_local_key_value_heads, dims_per_head, dim + ) + for i in range(num_shards) + ], + dim=0, + ).reshape(key_value_dim, dim) + + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 + ) + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 + ) + + state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + if model_size == "llama-2-7b": + # Unsharded + state_dict = { + "model.embed_tokens.weight": loaded["tok_embeddings.weight"], + "model.norm.weight": loaded["norm.weight"], + "lm_head.weight": loaded["output.weight"], + } + else: + state_dict = { + "model.norm.weight": loaded[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat( + [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 + ), + "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 + multiple_of = params["multiple_of"] if "multiple_of" in params else 256 + config = LlamaConfig( + hidden_size=dim, + intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), + num_attention_heads=params["n_heads"], + num_hidden_layers=params["n_layers"], + rms_norm_eps=params["norm_eps"], + num_key_value_heads=num_key_value_heads, + ) + config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. 
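+    # Freeing the per-shard state dicts and the raw checkpoints here keeps peak host RAM down before
+    # from_pretrained re-reads the sharded .bin files that were just written.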
+    del state_dict
+    del loaded
+    gc.collect()
+
+    print("Loading the checkpoint in a Llama model.")
+    model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+    # Avoid saving this as part of the config.
+    del model.config._name_or_path
+
+    print("Saving in the Transformers format.")
+    model.save_pretrained(model_path, safe_serialization=safe_serialization)
+    shutil.rmtree(tmp_model_path)
+
+
+def write_tokenizer(tokenizer_path, input_tokenizer_path):
+    # Initialize the tokenizer based on the `spm` model
+    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
+    print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
+    tokenizer = tokenizer_class(input_tokenizer_path, legacy=False)
+    tokenizer.save_pretrained(tokenizer_path)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input_dir", help="Location of LLaMA weights, which contains tokenizer.model and model folders",
+    )
+    parser.add_argument(
+        "--model_size",
+        choices=[
+            "llama-2-7b",
+            "llama-2-7b-chat",
+            "llama-2-13b",
+            "llama-2-13b-chat",
+            "llama-2-70b",
+            "llama-2-70b-chat",
+            "tokenizer_only",
+        ],
+    )
+    parser.add_argument(
+        "--output_dir", help="Location to write HF model and tokenizer",
+    )
+    parser.add_argument("--safe_serialization", action="store_true", help="Whether or not to save using `safetensors`.")
+    args = parser.parse_args()
+    if args.model_size != "tokenizer_only":
+        write_model(
+            model_path=args.output_dir,
+            input_base_path=os.path.join(args.input_dir, args.model_size),
+            model_size=args.model_size,
+            safe_serialization=args.safe_serialization,
+        )
+    spm_path = os.path.join(args.input_dir, "tokenizer.model")
+    write_tokenizer(args.output_dir, spm_path)
+
+
+if __name__ == "__main__":
+    main()
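As a quick sanity check after running convert_llama2_to_hf.py, the sketch below (not part of this patch) loads the converted Hugging Face checkpoint and runs a short greedy generation; the model_dir path simply mirrors OUTPUT_DIR from convert_llama_to_hf.sh and the prompt string is an arbitrary placeholder.

```py
# Minimal post-conversion sanity check (a sketch, not part of this PR).
# model_dir mirrors OUTPUT_DIR in convert_llama_to_hf.sh; adjust it to your own path.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model_dir = "/media/data3/pretrained_models/llama2_hf/llama-2-7b-chat"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = LlamaTokenizer.from_pretrained(model_dir)
model = LlamaForCausalLM.from_pretrained(model_dir, torch_dtype=dtype, low_cpu_mem_usage=True).to(device)
model.eval()

# A coherent continuation indicates the weights, config, and tokenizer line up.
inputs = tokenizer("The capital of France is", return_tensors="pt").to(device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=16, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```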