diff --git a/Jenkinsfile b/Jenkinsfile index 5dd12b0f4746..0dbad637941e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -3604,10 +3604,10 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' trainer.devices=2 \ trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ - trainer.val_check_interval=10 \ + trainer.val_check_interval=2 \ trainer.limit_val_batches=2 \ trainer.accumulate_grad_batches=1 \ - trainer.max_steps=10 \ + trainer.max_steps=3 \ trainer.precision=16 \ trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/bart_pretrain_results \ @@ -3627,15 +3627,15 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' model.decoder.bias_activation_fusion=False \ model.decoder.activations_checkpoint_method='block' \ model.decoder.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document]" + model.data.data_prefix='{train:[1.0,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document],test:[/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document], validation:[/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document]}'" sh "python examples/nlp/language_modeling/megatron_bart_pretraining.py \ trainer.devices=2 \ trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ - trainer.val_check_interval=10 \ - trainer.limit_val_batches=2 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=1 \ trainer.accumulate_grad_batches=1 \ - trainer.max_steps=10 \ + trainer.max_steps=6 \ trainer.precision=16 \ trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/bart_pretrain_results \ @@ -3656,7 +3656,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' model.decoder.bias_activation_fusion=False \ model.decoder.activations_checkpoint_method='block' \ model.decoder.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document]" + model.data.data_prefix='{train:[1.0,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document],test:[/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document], validation:[/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document]}'" sh "rm -rf examples/nlp/language_modeling/bart_pretrain_results" } } diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index 18c1258fe93c..b7de4854784b 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -118,14 +118,18 @@ model: sequence_parallel: False data: - # Path to data must be specified by the user. - # can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", + # Path to data must be specified by the user. 
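The Jenkins and config changes above introduce a dictionary form of model.data.data_prefix alongside the existing weighted list. As a minimal illustration (assuming only the omegaconf package, which Hydra uses for these configs; the pile paths are reused from the comments above), this is roughly how the two override styles look once parsed, and why the dataset builders later in this patch can branch on DictConfig:

from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from omegaconf.listconfig import ListConfig

# Legacy weighted-list form: one blended corpus shared by all splits.
legacy = OmegaConf.create(
    "[0.5, /raid/data/pile/my-gpt3_00_text_document, 0.5, /raid/data/pile/my-gpt3_01_text_document]"
)
# New dictionary form: explicit train/validation/test corpora.
per_split = OmegaConf.create(
    "{train: [1.0, /raid/data/pile/my-gpt3_00_text_document], "
    "validation: [/raid/data/pile/my-gpt3_01_text_document], "
    "test: [/raid/data/pile/my-gpt3_01_text_document]}"
)

assert isinstance(legacy, ListConfig)
assert isinstance(per_split, DictConfig)
print(list(per_split.keys()))  # ['train', 'validation', 'test']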
+ # Supports List, String and Dictionary + # List : can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", # Or see example below: # data_prefix: # - .5 # - /raid/data/pile/my-gpt3_00_text_document # - .5 # - /raid/data/pile/my-gpt3_01_text_document + # Dictionary: can override from the CLI: "model.data.data_prefix={train:[1.0,/path/to/data], validation:[/path/to/data], test:[/path/to/test]}" + # Or see example below: + # "model.data.data_prefix: {train:[1.0,/path/to/data], validation:[/path/to/data], test:[/path/to/test]}" data_prefix: ??? index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix data_impl: mmap diff --git a/examples/nlp/language_modeling/conf/megatron_t5_config.yaml b/examples/nlp/language_modeling/conf/megatron_t5_config.yaml index 7a93c604366b..25e58f4a7e2d 100644 --- a/examples/nlp/language_modeling/conf/megatron_t5_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_t5_config.yaml @@ -92,14 +92,18 @@ model: apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this data: - # Path to data must be specified by the user. - # can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-t5_00_text_document,.5,/raid/data/pile/my-t5_01_text_document]", - # Or see example below: + # Path to data must be specified by the user. + # Supports List, String and Dictionary + # List : can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-t5_00_text_document,.5,/raid/data/pile/my-t5_01_text_document]", + # Or see example below: # data_prefix: # - .5 # - /raid/data/pile/my-t5_00_text_document # - .5 # - /raid/data/pile/my-t5_01_text_document + # Dictionary: can override from the CLI: "model.data.data_prefix={train:[1.0,/path/to/data], validation:[/path/to/data], test:[/path/to/test]}" + # Or see example below: + # "model.data.data_prefix: {train:[1.0,/path/to/data], validation:[/path/to/data], test:[/path/to/test]}" data_prefix: ??? index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix data_impl: mmap # mmap, retmmap, text_mmap, csv_mmap diff --git a/nemo/collections/nlp/data/language_modeling/megatron/base_dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/base_dataset_utils.py index ef7cd8ae5660..96b7f57f7dc7 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/base_dataset_utils.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/base_dataset_utils.py @@ -15,7 +15,7 @@ import math -def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples): +def get_datasets_weights_and_num_samples(data_prefix, num_samples): # The data prefix should be in the format of: # weight-1, data-prefix-1, weight-2, data-prefix-2, .. @@ -39,9 +39,11 @@ def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_sampl # TODO: check data leakage between train/val/test? datasets_train_valid_test_num_samples = [] for weight in weights: - datasets_train_valid_test_num_samples.append( - [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples] - ) + # Comes here when we have separate train, test and validation datasets.
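The branch that follows distinguishes a single per-split sample count (the new dictionary path) from the legacy [train, valid, test] triple. A standalone sketch of that arithmetic, illustrative only: the interleaved weight/prefix layout and the ~0.5% oversampling factor come from the function above, while the helper name weights_and_samples is made up.

import math
from typing import List, Sequence, Union

def weights_and_samples(data_prefix: List, num_samples: Union[int, Sequence[int]]):
    prefixes = [str(p) for p in data_prefix[1::2]]
    weights = [float(w) for w in data_prefix[0::2]]
    total = sum(weights)
    weights = [w / total for w in weights]

    per_dataset = []
    for w in weights:
        if isinstance(num_samples, int):
            # Separate train/test/validation datasets: one count per split.
            per_dataset.append(int(math.ceil(num_samples * w * 1.005)))
        else:
            # Legacy path: one [train, valid, test] triple shared by the blend.
            per_dataset.append([int(math.ceil(n * w * 1.005)) for n in num_samples])
    return prefixes, weights, per_dataset

print(weights_and_samples([0.5, "/data/a", 0.5, "/data/b"], 1000))
# (['/data/a', '/data/b'], [0.5, 0.5], [503, 503])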
+ if isinstance(num_samples, int): + datasets_train_valid_test_num_samples.append(int(math.ceil(num_samples * weight * 1.005))) + else: + datasets_train_valid_test_num_samples.append([int(math.ceil(val * weight * 1.005)) for val in num_samples]) return prefixes, weights, datasets_train_valid_test_num_samples diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py index c9af9982524e..e6caeb02967c 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py @@ -539,6 +539,264 @@ def make_text_memmap_bin_compatibility(text_memmap_ds): return text_memmap_ds +def get_dataset( + indexed_dataset, + start_index, + end_index, + cfg, + trainer, + num_samples, + masked_lm_prob, + short_seq_prob, + binary_head, + max_seq_length_dec, + dataset_type='standard_bert', + tokenizer=None, + max_ngram_size=3, + mean_ngram_size=None, + geometric_dist=True, + permutation=False, + whole_word_masking=True, + favor_long_ngrams=False, + delete_mask_prob=0, # This flag is used in BART only, and will not have effect on T5/BERT + respect_document_boundaries=True, + **kwargs, +): + + if dataset_type not in DSET_TYPES: + raise ValueError("Invalid dataset_type: ", dataset_type) + + # from nemo.collections.nlp.data.language_modeling.megatron.ict_dataset import ICTDataset + from nemo.collections.nlp.data.language_modeling.megatron.bert_dataset import BertDataset + from nemo.collections.nlp.data.language_modeling.megatron.t5_dataset import T5Dataset + from nemo.collections.nlp.data.language_modeling.megatron.ul2_dataset import UL2Dataset + from nemo.collections.nlp.data.language_modeling.megatron.bart_dataset import BARTDataset + from nemo.collections.nlp.data.language_modeling.megatron.length_distribution_type import LengthDistribution + + if dataset_type == DSET_TYPE_ICT: + raise NotImplementedError("ICT dataset is not implemented yet.") + ''' + dataset = ICTDataset( + block_dataset=indexed_dataset, + title_dataset=title_dataset, + query_in_block_prob=args.query_in_block_prob, + use_one_sent_docs=args.use_one_sent_docs, + binary_head=binary_head, + **kwargs, + ) + ''' + elif dataset_type == DSET_TYPE_T5: + assert tokenizer is not None, "Tokenizer is required for T5 dataset" + logging.info("Instatiating T5 Dataset ...") + documents = np.arange(start=start_index, stop=end_index, step=1, dtype=np.int32) + dataset = T5Dataset( + cfg=cfg, + trainer=trainer, + tokenizer=tokenizer, + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + max_seq_length_dec=max_seq_length_dec, + short_seq_prob=short_seq_prob, + max_ngram_size=max_ngram_size, + mean_ngram_size=mean_ngram_size, + geometric_dist=geometric_dist, + permutation=permutation, + whole_word_masking=whole_word_masking, + favor_long_ngrams=favor_long_ngrams, + documents=documents, + respect_document_boundaries=respect_document_boundaries, + **kwargs, + ) + elif dataset_type == DSET_TYPE_BERT: + logging.info("Instatiating BERT Dataset ...") + dataset = BertDataset( + cfg=cfg, + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + binary_head=binary_head, + tokenizer=tokenizer, + **kwargs, + ) + elif dataset_type == DSET_TYPE_T5_LM: + documents = np.arange(start=start_index, stop=end_index, step=1, dtype=np.int32) + logging.info("Instatiating T5 Prefix-LM Dataset ...") + dataset = T5LMAdaptedDataset( + cfg=cfg, + 
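The new get_dataset helper above centralizes the per-type construction that previously lived inside _build_train_valid_test_datasets. The same dispatch can also be expressed as a lookup table; this sketch uses placeholder builder functions and illustrative type keys, not the real NeMo dataset classes or DSET_TYPE_* constants.

def _build_t5(**kwargs):
    return ("T5Dataset", kwargs)

def _build_bert(**kwargs):
    return ("BertDataset", kwargs)

def _build_bart(**kwargs):
    return ("BARTDataset", kwargs)

# Map a dataset_type string to its builder; unknown types fail loudly.
_BUILDERS = {"t5": _build_t5, "standard_bert": _build_bert, "bart": _build_bart}

def get_dataset_sketch(dataset_type, **kwargs):
    if dataset_type not in _BUILDERS:
        raise ValueError("Invalid dataset_type: ", dataset_type)
    return _BUILDERS[dataset_type](**kwargs)

print(get_dataset_sketch("bart", masked_lm_prob=0.15)[0])  # BARTDataset

A table keeps the factory short as dataset types accumulate, at the cost of hiding the per-type keyword differences that the explicit if/elif chain makes visible.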
trainer=trainer, + tokenizer=tokenizer, + documents=documents, + indexed_dataset=indexed_dataset, + num_samples=num_samples, + max_seq_length_encoder=kwargs["max_seq_length"], + max_seq_length_decoder=max_seq_length_dec, + **kwargs, + ) + elif dataset_type == DSET_TYPE_BART: + assert tokenizer is not None, "Tokenizer is required for BART dataset" + documents = np.arange(start=start_index, stop=end_index, step=1, dtype=np.int32) + logging.info("Instatiating BART Dataset ...") + dataset = BARTDataset( + cfg=cfg, + trainer=trainer, + tokenizer=tokenizer, + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + max_ngram_size=max_ngram_size, + mean_ngram_size=mean_ngram_size, + geometric_dist=geometric_dist, + permutation=permutation, + whole_word_masking=whole_word_masking, + favor_long_ngrams=favor_long_ngrams, + delete_mask_prob=delete_mask_prob, + documents=documents, + respect_document_boundaries=respect_document_boundaries, + **kwargs, + ) + elif dataset_type == DSET_TYPE_UL2: + assert tokenizer is not None, "Tokenizer is required for UL2 dataset" + documents = np.arange(start=start_index, stop=end_index, step=1, dtype=np.int32) + logging.info("Instatiating UL2 Dataset ...") + extreme_ngram_span_length_distribution = cfg.data.get( + "extreme_ngram_span_length_distribution", "truncated_normal" + ) + ngram_span_length_distribution = cfg.data.get("ngram_span_length_distribution", "geometric") + if extreme_ngram_span_length_distribution == "truncated_normal": + extreme_ngram_span_length_distribution = LengthDistribution.truncated_normal + elif extreme_ngram_span_length_distribution == "uniform": + extreme_ngram_span_length_distribution = LengthDistribution.uniform + elif extreme_ngram_span_length_distribution == "geometric": + extreme_ngram_span_length_distribution = LengthDistribution.geometric + + if ngram_span_length_distribution == "truncated_normal": + ngram_span_length_distribution = LengthDistribution.truncated_normal + elif ngram_span_length_distribution == "uniform": + ngram_span_length_distribution = LengthDistribution.uniform + elif ngram_span_length_distribution == "geometric": + ngram_span_length_distribution = LengthDistribution.geometric + + dataset = UL2Dataset( + cfg=cfg, + trainer=trainer, + tokenizer=tokenizer, + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + max_seq_length_dec=max_seq_length_dec, + short_seq_prob=short_seq_prob, + max_ngram_size=max_ngram_size, + mean_ngram_size=mean_ngram_size, + ngram_span_length_distribution=ngram_span_length_distribution, + extreme_ngram_span_length_distribution=extreme_ngram_span_length_distribution, + permutation=permutation, + whole_word_masking=whole_word_masking, + favor_long_ngrams=favor_long_ngrams, + extreme_masked_lm_prob=cfg.data.get("extreme_masked_lm_prob", 0.5), + extreme_max_ngram_size=cfg.data.get("extreme_max_ngram_size", 128), + extreme_mean_ngram_size=cfg.data.get("extreme_mean_ngram_size", 64), + extreme_min_ngram_size=cfg.data.get("extreme_min_ngram_size", 32), + prefix_lm_pivot_mean=cfg.data.get("prefix_lm_pivot_mean", 0.25), + respect_document_boundaries=respect_document_boundaries, + documents=documents, + **kwargs, + ) + else: + raise NotImplementedError(f"Dataset type {dataset_type} not fully implemented.") + return dataset + + +def build_dataset( + cfg, + trainer, + data_prefix, + data_impl, + num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + max_seq_length_dec, + name, + dataset_type, + 
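The UL2 branch above converts the ngram_span_length_distribution strings into enum members with an if/elif ladder. Because Enum members can be looked up by value, the same conversion can be a one-liner; the class below is a stand-in whose string values are an assumption, not NeMo's actual LengthDistribution definition.

from enum import Enum

class LengthDistribution(Enum):  # stand-in; member names match the diff above
    truncated_normal = "truncated_normal"
    uniform = "uniform"
    geometric = "geometric"

def parse_length_distribution(name: str) -> LengthDistribution:
    try:
        # Enum value lookup replaces the if/elif chain.
        return LengthDistribution(name)
    except ValueError:
        raise ValueError(f"Unknown span-length distribution: {name!r}")

print(parse_length_distribution("geometric"))         # LengthDistribution.geometric
print(parse_length_distribution("truncated_normal"))  # LengthDistribution.truncated_normal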
tokenizer, + max_ngram_size, + mean_ngram_size, + geometric_dist, + permutation, + whole_word_masking, + favor_long_ngrams, + delete_mask_prob, + respect_document_boundaries, + data_impl_kwargs, +): + def _build_dataset(current_data_prefix, current_num_samples): + indexed_dataset = get_indexed_dataset_( + current_data_prefix, data_impl, skip_warmup, data_impl_kwargs=data_impl_kwargs + ) + total_num_of_documents = indexed_dataset.sizes.shape[0] + # Print stats about the splits. + logging.info(' > dataset split:') + logging.info(' Total {} documents is : {} '.format(name, total_num_of_documents)) + if hasattr(indexed_dataset, 'get_doc_idx'): + doc_idx_ptr = indexed_dataset.get_doc_idx() + indexed_dataset.set_doc_idx(doc_idx_ptr[0:total_num_of_documents]) + + kwargs = dict( + name=name, + data_prefix=current_data_prefix, + num_epochs=None, + max_num_samples=int(current_num_samples), + max_seq_length=max_seq_length, + seed=seed, + ) + + dataset = get_dataset( + indexed_dataset, + 0, + total_num_of_documents, + cfg, + trainer, + current_num_samples, + masked_lm_prob, + short_seq_prob, + binary_head, + max_seq_length_dec, + dataset_type, + tokenizer, + max_ngram_size, + mean_ngram_size, + geometric_dist, + permutation, + whole_word_masking, + favor_long_ngrams, + delete_mask_prob, + respect_document_boundaries, + **kwargs, + ) + + # Set the original pointer so dataset remains the main dataset. + if hasattr(indexed_dataset, 'set_doc_idx'): + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. + assert indexed_dataset.doc_idx[0] == 0 + assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) + return dataset + + if len(data_prefix) == 1: + return _build_dataset(data_prefix[0], num_samples) + + else: + output = get_datasets_weights_and_num_samples(data_prefix, num_samples) + prefixes, weights, datasets_num_samples = output + datasets = [] + for i in range(len(prefixes)): + dataset = _build_dataset(prefixes[i], datasets_num_samples[i]) + datasets.append(dataset) + return BlendableDataset(datasets, weights, num_samples) + + def build_train_valid_test_datasets( cfg, trainer, @@ -581,14 +839,20 @@ def build_train_valid_test_datasets( "respect_document_boundaries=False is not compatible with text_memmap and csv_memmap (data_impl_kwargs != {})" ) - if len(data_prefix) == 1: - return _build_train_valid_test_datasets( + if isinstance(data_prefix, DictConfig): + assert ( + data_prefix.get('train') is not None + and data_prefix.get('test') is not None + and data_prefix.get('validation') is not None + ), f"Data prefix dictionary should have train, test and validation keys. data_prefix currently has only {data_prefix.keys()}" + if cfg.data.splits_string is not None: + logging.warning(cfg.data.splits_string + " ignored since data prefix is of type dictionary.") + train_ds = build_dataset( cfg, trainer, - data_prefix[0], + data_prefix["train"], data_impl, - splits_string, - train_valid_test_num_samples, + int(train_valid_test_num_samples[0]), max_seq_length, masked_lm_prob, short_seq_prob, @@ -596,6 +860,7 @@ def build_train_valid_test_datasets( skip_warmup, binary_head, max_seq_length_dec, + "train", dataset_type=dataset_type, tokenizer=tokenizer, max_ngram_size=max_ngram_size, @@ -608,24 +873,12 @@ def build_train_valid_test_datasets( respect_document_boundaries=respect_document_boundaries, data_impl_kwargs=data_impl_kwargs, ) - # Blending dataset. - # Parse the values. 
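build_dataset above follows a single pattern: one prefix builds a plain dataset, several prefixes build per-prefix datasets that are wrapped in a weighted blend. A toy sketch of that pattern; BlendedSketch stands in for NeMo's BlendableDataset and the per-prefix data is faked, so only the control flow is meaningful.

import random

class BlendedSketch:
    # Toy stand-in for BlendableDataset: each global index is assigned to a
    # child dataset with probability proportional to its weight.
    def __init__(self, datasets, weights, size, seed=1234):
        rng = random.Random(seed)
        self.datasets = datasets
        self.assignments = rng.choices(range(len(datasets)), weights=weights, k=size)

    def __len__(self):
        return len(self.assignments)

    def __getitem__(self, idx):
        child = self.datasets[self.assignments[idx]]
        return child[idx % len(child)]

def build_split(prefixes, weights, num_samples):
    # Fake one tiny "indexed dataset" per prefix; the real code memory-maps files.
    datasets = [[f"{p}/doc{i}" for i in range(8)] for p in prefixes]
    if len(datasets) == 1:
        return datasets[0]
    return BlendedSketch(datasets, weights, num_samples)

train = build_split(["/data/a", "/data/b"], [0.7, 0.3], num_samples=16)
print(len(train), train[0])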
- output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) - prefixes, weights, datasets_train_valid_test_num_samples = output - train_n, valid_n, test_n = map(sum, zip(*datasets_train_valid_test_num_samples)) - - # Build individual datasets. - train_datasets = [] - valid_datasets = [] - test_datasets = [] - for i in range(len(prefixes)): - train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + validation_ds = build_dataset( cfg, trainer, - prefixes[i], + data_prefix["validation"], data_impl, - splits_string, - datasets_train_valid_test_num_samples[i], + int(train_valid_test_num_samples[1]), max_seq_length, masked_lm_prob, short_seq_prob, @@ -633,6 +886,7 @@ def build_train_valid_test_datasets( skip_warmup, binary_head, max_seq_length_dec, + "valid", dataset_type=dataset_type, tokenizer=tokenizer, max_ngram_size=max_ngram_size, @@ -645,25 +899,119 @@ def build_train_valid_test_datasets( respect_document_boundaries=respect_document_boundaries, data_impl_kwargs=data_impl_kwargs, ) - if train_ds: - train_datasets.append(train_ds) - if valid_ds: - valid_datasets.append(valid_ds) - if test_ds: - test_datasets.append(test_ds) - - # Blend. - blending_train_dataset = None - if train_datasets: - blending_train_dataset = BlendableDataset(train_datasets, weights, train_n) - blending_valid_dataset = None - if valid_datasets: - blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_n) - blending_test_dataset = None - if test_datasets: - blending_test_dataset = BlendableDataset(test_datasets, weights, test_n) - - return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) + test_ds = build_dataset( + cfg, + trainer, + data_prefix["test"], + data_impl, + int(train_valid_test_num_samples[2]), + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + max_seq_length_dec, + "test", + dataset_type=dataset_type, + tokenizer=tokenizer, + max_ngram_size=max_ngram_size, + mean_ngram_size=mean_ngram_size, + geometric_dist=geometric_dist, + permutation=permutation, + whole_word_masking=whole_word_masking, + favor_long_ngrams=favor_long_ngrams, + delete_mask_prob=delete_mask_prob, + respect_document_boundaries=respect_document_boundaries, + data_impl_kwargs=data_impl_kwargs, + ) + return train_ds, validation_ds, test_ds + + else: + + if len(data_prefix) == 1: + return _build_train_valid_test_datasets( + cfg, + trainer, + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + max_seq_length_dec, + dataset_type=dataset_type, + tokenizer=tokenizer, + max_ngram_size=max_ngram_size, + mean_ngram_size=mean_ngram_size, + geometric_dist=geometric_dist, + permutation=permutation, + whole_word_masking=whole_word_masking, + favor_long_ngrams=favor_long_ngrams, + delete_mask_prob=delete_mask_prob, + respect_document_boundaries=respect_document_boundaries, + data_impl_kwargs=data_impl_kwargs, + ) + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + train_n, valid_n, test_n = map(sum, zip(*datasets_train_valid_test_num_samples)) + + # Build individual datasets. 
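The DictConfig branch above insists on all three split keys, warns that splits_string is ignored, and gives each split its own entry of train_valid_test_num_samples. A dependency-free sketch of that contract (a plain-mapping check instead of OmegaConf's DictConfig; the function name is illustrative):

from collections.abc import Mapping

REQUIRED_SPLITS = ("train", "validation", "test")

def per_split_num_samples(data_prefix, train_valid_test_num_samples):
    # Returns {split: num_samples} for the dictionary form, or None for the
    # legacy list/str form, which keeps using splits_string.
    if not isinstance(data_prefix, Mapping):
        return None
    missing = [k for k in REQUIRED_SPLITS if data_prefix.get(k) is None]
    if missing:
        raise ValueError(f"data_prefix dictionary is missing keys: {missing}")
    return dict(zip(REQUIRED_SPLITS, (int(n) for n in train_valid_test_num_samples)))

print(per_split_num_samples(
    {"train": [1.0, "/data/a"], "validation": ["/data/a"], "test": ["/data/a"]},
    [100000, 1000, 1000],
))
# {'train': 100000, 'validation': 1000, 'test': 1000}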
+ train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + cfg, + trainer, + prefixes[i], + data_impl, + splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + max_seq_length_dec, + dataset_type=dataset_type, + tokenizer=tokenizer, + max_ngram_size=max_ngram_size, + mean_ngram_size=mean_ngram_size, + geometric_dist=geometric_dist, + permutation=permutation, + whole_word_masking=whole_word_masking, + favor_long_ngrams=favor_long_ngrams, + delete_mask_prob=delete_mask_prob, + respect_document_boundaries=respect_document_boundaries, + data_impl_kwargs=data_impl_kwargs, + ) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights, train_n) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_n) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights, test_n) + + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) def _build_train_valid_test_datasets( @@ -693,9 +1041,6 @@ def _build_train_valid_test_datasets( data_impl_kwargs={}, ): - if dataset_type not in DSET_TYPES: - raise ValueError("Invalid dataset_type: ", dataset_type) - # Indexed dataset. indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup, data_impl_kwargs=data_impl_kwargs) @@ -758,135 +1103,29 @@ def build_dataset(index, name): seed=seed, ) - if dataset_type == DSET_TYPE_ICT: - raise NotImplementedError("ICT dataset is not implemented yet.") - ''' - dataset = ICTDataset( - block_dataset=indexed_dataset, - title_dataset=title_dataset, - query_in_block_prob=args.query_in_block_prob, - use_one_sent_docs=args.use_one_sent_docs, - binary_head=binary_head, - **kwargs, - ) - ''' - elif dataset_type == DSET_TYPE_T5: - assert tokenizer is not None, "Tokenizer is required for T5 dataset" - logging.info("Instatiating T5 Dataset ...") - documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) - dataset = T5Dataset( - cfg=cfg, - trainer=trainer, - tokenizer=tokenizer, - indexed_dataset=indexed_dataset, - masked_lm_prob=masked_lm_prob, - max_seq_length_dec=max_seq_length_dec, - short_seq_prob=short_seq_prob, - max_ngram_size=max_ngram_size, - mean_ngram_size=mean_ngram_size, - geometric_dist=geometric_dist, - permutation=permutation, - whole_word_masking=whole_word_masking, - favor_long_ngrams=favor_long_ngrams, - documents=documents, - respect_document_boundaries=respect_document_boundaries, - **kwargs, - ) - elif dataset_type == DSET_TYPE_BERT: - logging.info("Instatiating BERT Dataset ...") - dataset = BertDataset( - cfg=cfg, - indexed_dataset=indexed_dataset, - masked_lm_prob=masked_lm_prob, - short_seq_prob=short_seq_prob, - binary_head=binary_head, - tokenizer=tokenizer, - **kwargs, - ) - elif dataset_type == DSET_TYPE_T5_LM: - documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) - logging.info("Instatiating T5 Prefix-LM Dataset ...") - dataset = T5LMAdaptedDataset( - cfg=cfg, - trainer=trainer, - tokenizer=tokenizer, - documents=documents, - 
indexed_dataset=indexed_dataset, - num_samples=int(train_valid_test_num_samples[index]), - max_seq_length_encoder=max_seq_length, - max_seq_length_decoder=max_seq_length_dec, - **kwargs, - ) - elif dataset_type == DSET_TYPE_BART: - assert tokenizer is not None, "Tokenizer is required for BART dataset" - documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) - logging.info("Instatiating BART Dataset ...") - dataset = BARTDataset( - cfg=cfg, - trainer=trainer, - tokenizer=tokenizer, - indexed_dataset=indexed_dataset, - masked_lm_prob=masked_lm_prob, - short_seq_prob=short_seq_prob, - max_ngram_size=max_ngram_size, - mean_ngram_size=mean_ngram_size, - geometric_dist=geometric_dist, - permutation=permutation, - whole_word_masking=whole_word_masking, - favor_long_ngrams=favor_long_ngrams, - delete_mask_prob=delete_mask_prob, - documents=documents, - respect_document_boundaries=respect_document_boundaries, - **kwargs, - ) - elif dataset_type == DSET_TYPE_UL2: - assert tokenizer is not None, "Tokenizer is required for UL2 dataset" - documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) - logging.info("Instatiating UL2 Dataset ...") - extreme_ngram_span_length_distribution = cfg.data.get( - "extreme_ngram_span_length_distribution", "truncated_normal" - ) - ngram_span_length_distribution = cfg.data.get("ngram_span_length_distribution", "geometric") - if extreme_ngram_span_length_distribution == "truncated_normal": - extreme_ngram_span_length_distribution = LengthDistribution.truncated_normal - elif extreme_ngram_span_length_distribution == "uniform": - extreme_ngram_span_length_distribution = LengthDistribution.uniform - elif extreme_ngram_span_length_distribution == "geometric": - extreme_ngram_span_length_distribution = LengthDistribution.geometric - - if ngram_span_length_distribution == "truncated_normal": - ngram_span_length_distribution = LengthDistribution.truncated_normal - elif ngram_span_length_distribution == "uniform": - ngram_span_length_distribution = LengthDistribution.uniform - elif ngram_span_length_distribution == "geometric": - ngram_span_length_distribution = LengthDistribution.geometric - - dataset = UL2Dataset( - cfg=cfg, - trainer=trainer, - tokenizer=tokenizer, - indexed_dataset=indexed_dataset, - masked_lm_prob=masked_lm_prob, - max_seq_length_dec=max_seq_length_dec, - short_seq_prob=short_seq_prob, - max_ngram_size=max_ngram_size, - mean_ngram_size=mean_ngram_size, - ngram_span_length_distribution=ngram_span_length_distribution, - extreme_ngram_span_length_distribution=extreme_ngram_span_length_distribution, - permutation=permutation, - whole_word_masking=whole_word_masking, - favor_long_ngrams=favor_long_ngrams, - extreme_masked_lm_prob=cfg.data.get("extreme_masked_lm_prob", 0.5), - extreme_max_ngram_size=cfg.data.get("extreme_max_ngram_size", 128), - extreme_mean_ngram_size=cfg.data.get("extreme_mean_ngram_size", 64), - extreme_min_ngram_size=cfg.data.get("extreme_min_ngram_size", 32), - prefix_lm_pivot_mean=cfg.data.get("prefix_lm_pivot_mean", 0.25), - respect_document_boundaries=respect_document_boundaries, - documents=documents, - **kwargs, - ) - else: - raise NotImplementedError("Dataset type not fully implemented.") + dataset = get_dataset( + indexed_dataset, + splits[index], + splits[index + 1], + cfg, + trainer, + int(train_valid_test_num_samples[index]), + masked_lm_prob, + short_seq_prob, + binary_head, + max_seq_length_dec, + dataset_type, + tokenizer, + max_ngram_size, + mean_ngram_size, 
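For the legacy (non-dictionary) path, _build_train_valid_test_datasets still derives the [start_index, end_index) document ranges it now forwards to get_dataset from splits_string. A hedged approximation of that derivation, Megatron-style; split_boundaries is not the NeMo helper and the exact rounding may differ.

def split_boundaries(splits_string, num_documents):
    fractions = [float(s) for s in splits_string.replace("/", ",").split(",")]
    total = sum(fractions)
    fractions = [f / total for f in fractions]
    bounds = [0]
    for frac in fractions:
        bounds.append(bounds[-1] + int(round(frac * num_documents)))
    bounds[-1] = num_documents  # absorb rounding drift in the last split
    return list(zip(bounds[:-1], bounds[1:]))

print(split_boundaries("980,10,10", 10000))
# [(0, 9800), (9800, 9900), (9900, 10000)]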
+ geometric_dist, + permutation, + whole_word_masking, + favor_long_ngrams, + delete_mask_prob, + respect_document_boundaries, + **kwargs, + ) # Set the original pointer so dataset remains the main dataset. if hasattr(indexed_dataset, 'set_doc_idx'): @@ -894,7 +1133,7 @@ def build_dataset(index, name): # Checks. assert indexed_dataset.doc_idx[0] == 0 assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) - return dataset + return dataset train_dataset = build_dataset(0, 'train') valid_dataset = build_dataset(1, 'valid') diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py index 3648b2c57938..592e423a092b 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py @@ -19,6 +19,7 @@ import numpy as np import torch +from omegaconf.dictconfig import DictConfig from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( get_datasets_weights_and_num_samples, @@ -39,6 +40,40 @@ HAVE_APEX = False +def build_dataset(cfg, trainer, data_prefix, data_impl, num_samples, seq_length, seed, skip_warmup, tokenizer, name): + def _build_dataset(current_data_prefix, current_num_samples): + indexed_dataset = get_indexed_dataset_(current_data_prefix, data_impl, skip_warmup) + total_num_of_documents = indexed_dataset.sizes.shape[0] + # Print stats about the splits. + logging.info(' > dataset split:') + logging.info(' Total {} documents is : {} '.format(name, total_num_of_documents)) + dataset = GPTDataset( + cfg, + trainer, + tokenizer, + name, + current_data_prefix, + np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32), + indexed_dataset, + current_num_samples, + seq_length, + seed, + ) + return dataset + + if len(data_prefix) == 1: + return _build_dataset(data_prefix[0], num_samples) + + else: + output = get_datasets_weights_and_num_samples(data_prefix, num_samples) + prefixes, weights, datasets_num_samples = output + datasets = [] + for i in range(len(prefixes)): + dataset = _build_dataset(prefixes[i], datasets_num_samples[i]) + datasets.append(dataset) + return BlendableDataset(datasets, weights, num_samples) + + def build_train_valid_test_datasets( cfg, trainer, @@ -51,66 +86,111 @@ def build_train_valid_test_datasets( skip_warmup, tokenizer, ): - """Build train, valid, and test datasets.""" - - # Single dataset. - if len(data_prefix) == 1: - return _build_train_valid_test_datasets( + if isinstance(data_prefix, DictConfig): + assert ( + data_prefix.get('train') is not None + and data_prefix.get('test') is not None + and data_prefix.get('validation') is not None + ), f"Data prefix dictionary should have train, test and validation keys. data_prefix currently has only {data_prefix.keys()}" + if cfg.data.splits_string is not None: + logging.warning(cfg.data.splits_string + " ignored since data prefix is of type dictionary.") + train_ds = build_dataset( cfg, trainer, - data_prefix[0], + data_prefix["train"], data_impl, - splits_string, - train_valid_test_num_samples, + train_valid_test_num_samples[0], seq_length, seed, skip_warmup, tokenizer, + "train", ) - - # Blending dataset. - # Parse the values. - output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) - prefixes, weights, datasets_train_valid_test_num_samples = output - - # Build individual datasets. 
- train_datasets = [] - valid_datasets = [] - test_datasets = [] - for i in range(len(prefixes)): - train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + validation_ds = build_dataset( + cfg, + trainer, + data_prefix["validation"], + data_impl, + train_valid_test_num_samples[1], + seq_length, + seed, + skip_warmup, + tokenizer, + "valid", + ) + test_ds = build_dataset( cfg, trainer, - prefixes[i], + data_prefix["test"], data_impl, - splits_string, - datasets_train_valid_test_num_samples[i], + train_valid_test_num_samples[2], seq_length, seed, skip_warmup, tokenizer, + "test", ) - if train_ds: - train_datasets.append(train_ds) - if valid_ds: - valid_datasets.append(valid_ds) - if test_ds: - test_datasets.append(test_ds) - - train_n, valid_n, test_n = map(sum, zip(*datasets_train_valid_test_num_samples)) - - # Blend. - blending_train_dataset = None - if train_datasets: - blending_train_dataset = BlendableDataset(train_datasets, weights, train_n) - blending_valid_dataset = None - if valid_datasets: - blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_n) - blending_test_dataset = None - if test_datasets: - blending_test_dataset = BlendableDataset(test_datasets, weights, test_n) - - return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) + return train_ds, validation_ds, test_ds + + else: + # Single dataset. + if len(data_prefix) == 1: + return _build_train_valid_test_datasets( + cfg, + trainer, + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + tokenizer, + ) + + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + cfg, + trainer, + prefixes[i], + data_impl, + splits_string, + datasets_train_valid_test_num_samples[i], + seq_length, + seed, + skip_warmup, + tokenizer, + ) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + train_n, valid_n, test_n = map(sum, zip(*datasets_train_valid_test_num_samples)) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights, train_n) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_n) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights, test_n) + + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) def _build_train_valid_test_datasets(
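To make the GPT-side behaviour concrete, here is a minimal, self-contained check (omegaconf plus the standard logging module; it mirrors, but does not import, the NeMo code above) showing that a dictionary data_prefix is detected as a DictConfig and that splits_string is then only warned about rather than used.

import logging

from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

cfg = OmegaConf.create(
    {
        "data": {
            "splits_string": "980,10,10",
            "data_prefix": {
                "train": [1.0, "/data/pile_00_text_document"],
                "validation": ["/data/pile_00_text_document"],
                "test": ["/data/pile_00_text_document"],
            },
        }
    }
)

if isinstance(cfg.data.data_prefix, DictConfig) and cfg.data.splits_string is not None:
    logging.warning(cfg.data.splits_string + " ignored since data prefix is of type dictionary.")
    # The per-split datasets are then built directly from cfg.data.data_prefix["train"], etc.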