[FEATURE] DatasetLoader for BERT Pretraining (#1099)
szhengac authored Feb 3, 2020
1 parent 8f475df commit 8d43da9
Showing 14 changed files with 792 additions and 455 deletions.
2 changes: 2 additions & 0 deletions docs/api/modules/data.rst
@@ -229,6 +229,7 @@ iterate through sequence data.
SortedSampler
FixedBucketSampler
SortedBucketSampler
SplitSampler

The `FixedBucketSampler` uses following bucket scheme classes to generate bucket keys.

@@ -248,6 +249,7 @@ DataLoaders loads data from a dataset and returns mini-batches of data
:nosignatures:

ShardedDataLoader
DatasetLoader

Utilities
---------
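The `DatasetLoader` added to `gluonnlp.data` (and documented above) is aimed at pretraining corpora that are sharded across many pre-processed files: shard files are sampled first, each shard is opened as its own dataset, and mini-batches are then drawn within that shard, so the whole corpus never has to sit in memory at once. Below is a minimal, self-contained sketch of that two-stage pattern; it is not the DatasetLoader API itself (whose constructor arguments do not appear in this diff), and the 'input_ids' key is only an assumption about the shard layout.

import glob
import random

import numpy as np


def two_stage_batches(file_pattern, batch_size):
    # Stage 1: visit the pre-processed shard files in a random order.
    shards = glob.glob(file_pattern)
    random.shuffle(shards)
    for shard in shards:
        # Stage 2: open one shard and draw shuffled mini-batches from it.
        data = np.load(shard)['input_ids']   # assumed key, for illustration only
        order = np.random.permutation(len(data))
        for start in range(0, len(order), batch_size):
            yield data[order[start:start + batch_size]]

# for batch in two_stage_batches('out_dir/*.npz', batch_size=32):
#     ...  # feed the batch to the BERT pretraining step

The class added by this commit presumably layers pluggable samplers, a batchify function, and worker-based prefetching on top of this basic pattern; see the API docs entry added above.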
3 changes: 2 additions & 1 deletion scripts/bert/data/__init__.py
@@ -17,4 +17,5 @@

# pylint: disable=wildcard-import
"""BERT data."""
from . import classification, embedding, transform, dataloader, pretrain

from . import classification, embedding, transform
@@ -224,6 +224,7 @@ def create_training_instances(x):
all_documents = [[]]

for input_file in input_files:
logging.debug('*** Tokenizing file %s***', input_file)
with io.open(input_file, 'r', encoding='utf-8') as reader:
lines = reader.readlines()
num_lines = len(lines)
@@ -248,9 +249,9 @@ def create_training_instances(x):
else:
all_documents[-1].append(line)

# remove the last empty document if any
if not all_documents[-1]:
all_documents = all_documents[:-1]
# remove the empty document if any
all_documents = [x for x in all_documents if x]
random.shuffle(all_documents)

# generate training instances
instances = []
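The replacement of the trailing-empty-document check with a list comprehension matters when an input file contains consecutive blank lines: those leave empty lists in the middle of all_documents, which the old code never removed. A standalone illustration (not code from this commit):

all_documents = [['doc A sentence'], [], ['doc B sentence'], []]

# old behaviour: only a trailing empty document was dropped
old = all_documents[:-1] if not all_documents[-1] else all_documents
# -> [['doc A sentence'], [], ['doc B sentence']]   (an empty document survives)

# new behaviour: every empty document is filtered out before shuffling
new = [x for x in all_documents if x]
# -> [['doc A sentence'], ['doc B sentence']]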
@@ -264,6 +265,7 @@ def create_training_instances(x):
instances_results = worker_pool.map(create_instances_from_document, process_args)
for instances_result in instances_results:
instances.extend(instances_result)
random.shuffle(instances)
npz_instances = worker_pool.apply(convert_to_npz, (instances, max_seq_length))
else:
for _ in range(dupe_factor):
@@ -273,6 +275,7 @@ def create_training_instances(x):
(all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, whole_word_mask,
vocab, tokenizer)))
random.shuffle(instances)
npz_instances = convert_to_npz(instances, max_seq_length)

(input_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights,
@@ -351,8 +354,8 @@ def create_instances_from_document(x):

# randomly choose a document other than itself
random_document_index = random.randint(0, len(all_documents) - 2)
if random_document_index >= document_index:
random_document_index += 1
if random_document_index == document_index:
random_document_index = len(all_documents) - 1

random_document = all_documents[random_document_index]
random_start = random.randint(0, len(random_document) - 1)
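Both the removed and the added remapping draw the "random document" uniformly from every document except the current one; the new form simply redirects a collision with document_index to the last document rather than shifting indices upward. A quick standalone check (not part of the commit):

import random
from collections import Counter

n, document_index = 5, 2
counts = Counter()
for _ in range(100000):
    idx = random.randint(0, n - 2)
    if idx == document_index:
        idx = n - 1
    counts[idx] += 1

# document_index (2) is never selected; indices 0, 1, 3 and 4 each
# appear roughly 25,000 times, i.e. the choice stays uniform.
print(sorted(counts.items()))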
@@ -646,7 +649,7 @@ def main():
parser.add_argument(
'--dupe_factor',
type=int,
default=1,
default=5,
help='Number of times to duplicate the input data (with different masks).')

parser.add_argument(
242 changes: 0 additions & 242 deletions scripts/bert/data/dataloader.py

This file was deleted.

