
Commit e8d609a

Myle Ott authored and facebook-github-bot committed
Add new Masked LM task + criterion
Summary: Pull Request resolved: fairinternal/fairseq-py#761

Differential Revision: D16421335

Pulled By: myleott

fbshipit-source-id: 257d92c2b90361147642e2baa38486b4d18f6297
1 parent 654affc commit e8d609a
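
The two files below hook into fairseq's plugin system through the @register_task('masked_lm') and @register_criterion('masked_lm') decorators, which is what makes the new components selectable as --task and --criterion values on the command line. A quick sanity check after installing a checkout that includes this commit might look like the following sketch; it assumes fairseq exposes the TASK_REGISTRY and CRITERION_REGISTRY dictionaries from fairseq.tasks and fairseq.criterions (as it did around this time), which is not shown in this diff.

    # hypothetical sanity check, not part of this commit
    from fairseq.tasks import TASK_REGISTRY
    from fairseq.criterions import CRITERION_REGISTRY

    # both names should now be registered and usable from the CLI
    assert 'masked_lm' in TASK_REGISTRY
    assert 'masked_lm' in CRITERION_REGISTRY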

File tree

2 files changed: +299 −0 lines


fairseq/criterions/masked_lm.py

+72
@@ -0,0 +1,72 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch
import torch.nn.functional as F

from fairseq import utils

from . import FairseqCriterion, register_criterion


@register_criterion('masked_lm')
class MaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in masked language model (MLM) training.
    """

    def __init__(self, args, task):
        super().__init__(args, task)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        # compute MLM loss
        logits = model(**sample['net_input'], last_state_only=True)[0]
        targets = model.get_targets(sample, [logits])
        loss = F.nll_loss(
            F.log_softmax(
                logits.view(-1, logits.size(-1)),
                dim=-1,
                dtype=torch.float32,
            ),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )

        sample_size = targets.ne(self.padding_idx).int().sum().item()

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss = sum(log.get('loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        agg_output = {
            'loss': loss / sample_size / math.log(2),
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
        return agg_output
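
The criterion sums the token-level negative log-likelihood over non-padding positions and reports sample_size as that token count, so aggregate_logging_outputs ends up logging loss in bits per scored token (the division by math.log(2) converts nats to bits). A standalone sketch of the same arithmetic on toy tensors (the shapes, values, and pad index here are made up for illustration and are not taken from this commit):

    import math
    import torch
    import torch.nn.functional as F

    pad_idx = 1
    # toy batch: 2 sequences, 4 positions, vocabulary of 6 types
    logits = torch.randn(2, 4, 6)
    targets = torch.tensor([[4, 2, 5, pad_idx],
                            [3, 3, pad_idx, pad_idx]])

    # same loss computation as MaskedLmLoss.forward
    loss = F.nll_loss(
        F.log_softmax(logits.view(-1, logits.size(-1)), dim=-1, dtype=torch.float32),
        targets.view(-1),
        reduction='sum',
        ignore_index=pad_idx,
    )
    sample_size = targets.ne(pad_idx).int().sum().item()  # 5 non-padding tokens

    # same normalization as aggregate_logging_outputs:
    # summed nats -> average nats per token -> bits per token
    bits_per_token = loss.item() / sample_size / math.log(2)
    print(bits_per_token)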

fairseq/tasks/masked_lm.py

+227
@@ -0,0 +1,227 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import itertools
import os

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import (
    ConcatDataset,
    data_utils,
    Dictionary,
    encoders,
    IdDataset,
    indexed_dataset,
    MaskTokensDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    PadDataset,
    PrependTokenDataset,
    SortDataset,
    TokenBlockDataset,
)
from fairseq.tasks import FairseqTask, register_task


@register_task('masked_lm')
class MaskedLMTask(FairseqTask):
    """Task for training masked language models (e.g., BERT, RoBERTa)."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('data', help='colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner')
        parser.add_argument('--sample-break-mode', default='complete',
                            choices=['none', 'complete', 'complete_doc', 'eos'],
                            help='If omitted or "none", fills each sample with tokens-per-sample '
                                 'tokens. If set to "complete", splits samples only at the end '
                                 'of sentence, but may include multiple sentences per sample. '
                                 '"complete_doc" is similar but respects doc boundaries. '
                                 'If set to "eos", includes only one sentence per sample.')
        parser.add_argument('--tokens-per-sample', default=512, type=int,
                            help='max number of total tokens over all segments '
                                 'per sample for BERT dataset')
        parser.add_argument('--mask-prob', default=0.15, type=float,
                            help='probability of replacing a token with mask')
        parser.add_argument('--leave-unmasked-prob', default=0.1, type=float,
                            help='probability that a masked token is unmasked')
        parser.add_argument('--random-token-prob', default=0.1, type=float,
                            help='probability of replacing a token with a random token')
        parser.add_argument('--freq-weighted-replacement', action='store_true',
                            help='sample random replacement words based on word frequencies')
        parser.add_argument('--mask-whole-words', default=False, action='store_true',
                            help='mask whole words; you may also want to set --bpe')

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        # add mask token
        self.mask_idx = dictionary.add_symbol('<mask>')

    @classmethod
    def setup_task(cls, args, **kwargs):
        paths = args.data.split(':')
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
        print('| dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)

    def load_dataset(self, split, epoch=0, combine=False):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        # create masked input and targets
        if self.args.mask_whole_words:
            bpe = encoders.build_bpe(self.args)
            if bpe is not None:

                def is_beginning_of_word(i):
                    if i < self.source_dictionary.nspecial:
                        # special elements are always considered beginnings
                        return True
                    tok = self.source_dictionary[i]
                    if tok.startswith('madeupword'):
                        return True
                    try:
                        return bpe.is_beginning_of_word(tok)
                    except ValueError:
                        return True

                mask_whole_words = torch.ByteTensor(list(
                    map(is_beginning_of_word, range(len(self.source_dictionary)))
                ))
        else:
            mask_whole_words = None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id': IdDataset(),
                    'net_input': {
                        'src_tokens': PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths': NumelDataset(src_dataset, reduce=False),
                    },
                    'target': PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences': NumSamplesDataset(),
                    'ntokens': NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
        if self.args.also_lowercase_words:
            raise NotImplementedError
        src_dataset = PadDataset(
            TokenBlockDataset(
                src_tokens,
                src_lengths,
                self.args.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode='eos',
            ),
            pad_idx=self.source_dictionary.pad(),
            left_pad=False,
        )
        src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
        src_dataset = NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': src_dataset,
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
            },
            sizes=src_lengths,
        )
        if sort:
            src_dataset = SortDataset(src_dataset, sort_order=[src_lengths])
        return src_dataset

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary

    def get_average_masked_score(self, model, src_tokens, mask, **net_input):
        """Mask a set of tokens and return their average score."""
        masked_tokens = src_tokens.clone()
        masked_tokens[mask.byte()] = self.mask_idx
        net_output = model(src_tokens=masked_tokens, **net_input, last_state_only=True)
        lprobs = F.log_softmax(net_output[0], dim=-1, dtype=torch.float32)
        lprobs = lprobs.gather(-1, src_tokens.unsqueeze(-1)).squeeze(-1)
        mask = mask.type_as(lprobs)
        score = (lprobs * mask).sum(dim=-1) / mask.sum(dim=-1)
        return score
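
The masking hyperparameters above follow the BERT recipe: --mask-prob 0.15 selects roughly 15% of positions as prediction targets, and of those, --leave-unmasked-prob 0.1 are left unchanged, --random-token-prob 0.1 are replaced with a random token, and the remaining ~80% become <mask>. The sketch below illustrates that split on a toy sequence; it is an illustration of the policy only, not the MaskTokensDataset.apply_mask implementation from this commit, and all token IDs, indices, and sizes are made up.

    import numpy as np

    rng = np.random.RandomState(0)

    tokens = np.array([11, 37, 42, 5, 9, 23, 8, 17, 30, 12])  # made-up token IDs
    pad_idx, mask_idx, vocab_size = 1, 4, 100                  # made-up indices
    mask_prob, leave_unmasked_prob, random_token_prob = 0.15, 0.1, 0.1

    # pick ~15% of positions to score; everything else is padded out of the loss
    mask = rng.rand(len(tokens)) < mask_prob
    targets = np.where(mask, tokens, pad_idx)

    # of the picked positions: ~10% stay unchanged, ~10% become a random token,
    # and the remaining ~80% are replaced with <mask>
    u = rng.rand(len(tokens))
    inputs = tokens.copy()
    replace_with_mask = mask & (u >= leave_unmasked_prob + random_token_prob)
    replace_with_rand = mask & (u >= leave_unmasked_prob) & ~replace_with_mask

    inputs[replace_with_mask] = mask_idx
    inputs[replace_with_rand] = rng.randint(vocab_size, size=replace_with_rand.sum())
    print(inputs, targets)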
