-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathmodeling_unilm.py
715 lines (618 loc) · 31.6 KB
/
modeling_unilm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import math
import logging
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from transformers.modeling_utils import PreTrainedModel
from configuration_unilm import UnilmConfig
from transformers.modeling_bert import load_tf_weights_in_bert, BertPooler, BertIntermediate, BertOutput, BertPredictionHeadTransform, BertSelfOutput, BertLMPredictionHead, BertOnlyMLMHead, BertOnlyMLMHead, BertEmbeddings, BertOnlyNSPHead
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Archive map referenced by UnilmPreTrainedModel. Both model names map to an
# empty string, i.e. no archive location is recorded here.
UNILM_PRETRAINED_MODEL_ARCHIVE_MAP = {
'unilm-base-cased': "",
'unilm-large-cased': ""
}
# Alias so the rest of the file can refer to torch's LayerNorm under a BERT-style name.
BertLayerNorm = torch.nn.LayerNorm
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional history states.

    Queries are always computed from ``hidden_states``. When
    ``history_states`` is supplied, keys and values are computed from the
    concatenation ``[history_states, hidden_states]`` along dimension 1.
    """

    def __init__(self, config):
        """Build the query/key/value projections and the dropout layer.

        Reads ``hidden_size``, ``num_attention_heads`` and
        ``attention_probs_dropout_prob`` from ``config``.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of
                ``num_attention_heads``.
        """
        super(BertSelfAttention, self).__init__()
        head_count = config.num_attention_heads
        if config.hidden_size % head_count != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = head_count
        self.attention_head_size = int(config.hidden_size / head_count)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Creation order (query, key, value) is kept so parameter order and
        # default initialisation draws stay the same.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into (heads, head_size) and swap dims 1 and 2."""
        split_shape = x.size()[:-1] + (self.num_attention_heads,
                                       self.attention_head_size)
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, history_states=None):
        """Return the attention context for ``hidden_states``.

        ``attention_mask`` is added to the raw attention scores before the
        softmax, so it must broadcast against them. The result has the same
        leading dimensions as ``hidden_states`` with a last dimension of
        ``all_head_size``.
        """
        if history_states is None:
            key_value_input = hidden_states
        else:
            key_value_input = torch.cat((history_states, hidden_states), dim=1)

        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(key_value_input))
        values = self.transpose_for_scores(self.value(key_value_input))

        # The query is scaled before the matmul (not the scores afterwards).
        scaled_queries = queries / math.sqrt(self.attention_head_size)
        scores = torch.matmul(scaled_queries, keys.transpose(-1, -2))
        scores = scores + attention_mask

        probs = self.dropout(F.softmax(scores, dim=-1))

        context = torch.matmul(probs, values).permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return context.view(*merged_shape)
class BertAttention(nn.Module):
    """Self-attention followed by the ``BertSelfOutput`` block."""

    def __init__(self, config):
        """Create the ``self`` (attention) and ``output`` sub-modules from ``config``."""
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask, history_states=None):
        """Run attention over ``input_tensor`` and pass the context, together
        with the original ``input_tensor``, through the output block."""
        context = self.self(
            input_tensor, attention_mask, history_states=history_states)
        return self.output(context, input_tensor)
class BertLayer(nn.Module):
    """One encoder layer: attention, then intermediate, then output block."""

    def __init__(self, config):
        """Create the attention, intermediate and output sub-modules from ``config``."""
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, history_states=None):
        """Apply the layer to ``hidden_states``.

        The attention output is fed both to the intermediate block and, as the
        second argument, to the output block.
        """
        attended = self.attention(
            hidden_states, attention_mask, history_states=history_states)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules."""

    def __init__(self, config):
        """Build one ``BertLayer`` and deep-copy it once per hidden layer."""
        super(BertEncoder, self).__init__()
        template = BertLayer(config)
        copies = [copy.deepcopy(template)
                  for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(copies)

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, prev_embedding=None, prev_encoded_layers=None):
        """Run the layer stack and return a list of hidden states.

        With ``output_all_encoded_layers`` the list holds every layer's
        output; otherwise it holds only the last one.

        ``prev_embedding`` and ``prev_encoded_layers`` must be given together
        or not at all. When given, layer 0 receives ``prev_embedding`` as its
        ``history_states`` and layer ``i > 0`` receives
        ``prev_encoded_layers[i - 1]``.
        """
        assert (prev_embedding is None) == (prev_encoded_layers is None)
        use_history = (prev_embedding is not None) and (
            prev_encoded_layers is not None)

        collected = []
        history = prev_embedding if use_history else None
        for index, block in enumerate(self.layer):
            if use_history:
                hidden_states = block(
                    hidden_states, attention_mask, history_states=history)
                # The next layer attends over this layer's cached output.
                history = prev_encoded_layers[index]
            else:
                hidden_states = block(hidden_states, attention_mask)
            if output_all_encoded_layers:
                collected.append(hidden_states)

        if not output_all_encoded_layers:
            collected.append(hidden_states)
        return collected
class UnilmPreTrainedModel(PreTrainedModel):
    """Base class for the UniLM models in this file.

    Sets the class-level hooks (config class, archive map, TF weight loader,
    base-model prefix) and provides the weight initialisation routine.
    """
    config_class = UnilmConfig
    pretrained_model_archive_map = UNILM_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "unilm"

    def _init_weights(self, module):
        """Initialise the weights of a single ``module``.

        Linear and embedding weights are drawn from a normal distribution with
        std ``config.initializer_range``; LayerNorm is reset to weight 1 and
        bias 0; any linear bias is zeroed.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if is_linear and module.bias is not None:
            module.bias.data.zero_()
class UnilmModel(UnilmPreTrainedModel):
    """Embeddings + encoder + pooler backbone."""

    def __init__(self, config):
        """Create the embedding, encoder and pooler modules and initialise weights."""
        super(UnilmModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.init_weights()

    def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask):
        """Turn a 0/1 attention mask into an additive mask for attention scores.

        A 2-D mask gets two singleton dimensions inserted at positions 1 and 2;
        a 3-D mask gets one at position 1. The result is cast to the dtype of
        the model's first parameter and mapped so that 1 becomes 0.0 and 0
        becomes -10000.0. A missing mask defaults to all ones shaped like
        ``input_ids``.

        Raises:
            NotImplementedError: for masks that are neither 2-D nor 3-D.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            # Kept from the original; the value is not used further here.
            token_type_ids = torch.zeros_like(input_ids)

        rank = attention_mask.dim()
        if rank == 2:
            expanded = attention_mask[:, None, None, :]
        elif rank == 3:
            expanded = attention_mask[:, None, :, :]
        else:
            raise NotImplementedError

        # Match the parameter dtype (fp16 compatibility).
        param_dtype = next(self.parameters()).dtype
        expanded = expanded.to(dtype=param_dtype)
        return (1.0 - expanded) * -10000.0

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        """Encode ``input_ids``.

        Returns ``(encoded_layers, pooled_output)``. ``encoded_layers`` is the
        list of all layer outputs when ``output_all_encoded_layers`` is true,
        otherwise just the final layer's output.
        """
        additive_mask = self.get_extended_attention_mask(
            input_ids, token_type_ids, attention_mask)
        embedded = self.embeddings(input_ids, token_type_ids)
        layer_outputs = self.encoder(
            embedded, additive_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        final_states = layer_outputs[-1]
        pooled_output = self.pooler(final_states)
        if output_all_encoded_layers:
            return layer_outputs, pooled_output
        return final_states, pooled_output
class UnilmModelIncr(UnilmModel):
    """``UnilmModel`` variant for incremental decoding.

    Takes explicit ``position_ids`` and cached states from earlier steps, and
    additionally returns the embedding output so the caller can cache it.
    """

    def __init__(self, config):
        super(UnilmModelIncr, self).__init__(config)

    def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True, prev_embedding=None,
                prev_encoded_layers=None):
        """Encode ``input_ids`` given optional cached history.

        Returns ``(embedding_output, encoded_layers, pooled_output)``.
        ``encoded_layers`` is the list of all layer outputs when
        ``output_all_encoded_layers`` is true, otherwise only the last one.
        """
        additive_mask = self.get_extended_attention_mask(
            input_ids, token_type_ids, attention_mask)
        embedding_output = self.embeddings(
            input_ids, token_type_ids, position_ids)
        layer_outputs = self.encoder(
            embedding_output,
            additive_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            prev_embedding=prev_embedding,
            prev_encoded_layers=prev_encoded_layers)
        final_states = layer_outputs[-1]
        pooled_output = self.pooler(final_states)
        if output_all_encoded_layers:
            return embedding_output, layer_outputs, pooled_output
        return embedding_output, final_states, pooled_output
class LabelSmoothingLoss(_Loss):
    """KL-divergence loss against a label-smoothed target distribution.

    The target distribution puts ``1 - label_smoothing`` on the gold token and
    spreads ``label_smoothing`` evenly over ``tgt_vocab_size - 2`` entries
    (every entry except the gold token and ``ignore_index``). Positions whose
    target equals ``ignore_index`` get an all-zero distribution.
    """

    def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None, reduction='mean'):
        """Precompute the smoothed base row and store it as the ``one_hot`` buffer.

        ``label_smoothing`` must lie in (0, 1] and ``tgt_vocab_size`` must be
        positive; both are checked with assertions.
        """
        assert 0.0 < label_smoothing <= 1.0
        super(LabelSmoothingLoss, self).__init__(
            size_average=size_average, reduce=reduce, reduction=reduction)
        assert tgt_vocab_size > 0

        self.ignore_index = ignore_index
        self.tgt_vocab_size = tgt_vocab_size
        self.confidence = 1.0 - label_smoothing

        per_token_mass = label_smoothing / (tgt_vocab_size - 2)
        base_row = torch.full((tgt_vocab_size,), per_token_mass)
        base_row[ignore_index] = 0
        self.register_buffer('one_hot', base_row.unsqueeze(0))

    def forward(self, output, target):
        """Return the per-position loss.

        ``output`` holds log-probabilities whose third dimension equals
        ``tgt_vocab_size``; ``target`` holds token ids indexed by the first
        two dimensions. The result has shape ``(target.size(0),
        target.size(1))``: the element-wise KL terms summed over the
        vocabulary. No reduction over positions is applied here.
        """
        assert self.tgt_vocab_size == output.size(2)
        n_rows, n_cols = target.size(0), target.size(1)

        flat_output = output.view(-1, self.tgt_vocab_size)
        flat_target = target.view(-1)

        smoothed = self.one_hot.repeat(flat_target.size(0), 1)
        smoothed.scatter_(1, flat_target.unsqueeze(1), self.confidence)
        ignored = (flat_target == self.ignore_index).unsqueeze(1)
        smoothed.masked_fill_(ignored, 0)

        pointwise = F.kl_div(
            flat_output, smoothed.type_as(flat_output), reduction='none')
        return pointwise.view(n_rows, n_cols, -1).sum(2)
class UnilmForLM(UnilmPreTrainedModel):
    """UniLM backbone with a masked-LM head and a 2-way sequence-level head."""

    def __init__(self, config):
        """Build the backbone, both heads and their loss functions.

        A ``LabelSmoothingLoss`` is created only when ``config`` has a truthy
        ``label_smoothing`` attribute.
        """
        super(UnilmForLM, self).__init__(config)
        self.bert = UnilmModel(config)
        self.cls = BertOnlyMLMHead(config)
        self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none')
        smoothing = getattr(config, 'label_smoothing', None)
        if smoothing:
            self.crit_mask_lm_smoothed = LabelSmoothingLoss(
                smoothing, config.vocab_size, ignore_index=0, reduction='none')
        else:
            self.crit_mask_lm_smoothed = None
        self.num_labels = 2
        self.cls2 = BertOnlyNSPHead(config)
        self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """Tie (or clone, as ``_tie_or_clone_weights`` decides) the LM decoder
        weights to the input word embeddings."""
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, masked_pos=None, masked_weights=None, next_sentence_label=None):
        """Return prediction scores (no labels) or the training loss (with labels).

        Without ``masked_lm_labels``: returns LM scores for every position, or
        only for ``masked_pos`` when that is given.

        With ``masked_lm_labels``: returns the masked-LM loss, weighted by
        ``masked_weights`` and normalised by their sum, plus the
        sequence-level loss when ``next_sentence_label`` is given.
        """
        sequence_output, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)

        def select_positions(states, positions):
            # Gather the hidden vectors at the given positions along dim 1.
            index = positions.unsqueeze(2).expand(-1, -1, states.size(-1))
            return torch.gather(states, 1, index)

        if masked_lm_labels is None:
            if masked_pos is not None:
                sequence_output = select_positions(sequence_output, masked_pos)
            return self.cls(sequence_output)

        masked_scores = self.cls(select_positions(sequence_output, masked_pos))
        if self.crit_mask_lm_smoothed is not None:
            per_token_loss = self.crit_mask_lm_smoothed(
                F.log_softmax(masked_scores.float(), dim=-1), masked_lm_labels)
        else:
            per_token_loss = self.crit_mask_lm(
                masked_scores.transpose(1, 2).float(), masked_lm_labels)

        # Weighted sum of the per-token losses divided by the total weight;
        # the epsilon guards against an all-zero weight tensor.
        per_token_loss = per_token_loss.float()
        weights = masked_weights.type_as(per_token_loss)
        weighted = per_token_loss * weights
        masked_lm_loss = (weighted / (torch.sum(weights) + 1e-5)).sum()

        seq_relationship_score = self.cls2(pooled_output)
        if next_sentence_label is None:
            return masked_lm_loss

        next_sentence_loss = self.crit_next_sent(
            seq_relationship_score.view(-1, self.num_labels).float(), next_sentence_label.view(-1))
        return next_sentence_loss + masked_lm_loss
class UnilmForSeq2Seq(UnilmPreTrainedModel):
    """UniLM backbone with a masked-LM head, used for seq2seq training."""

    def __init__(self, config):
        """Build the backbone, the LM head and its loss functions.

        A ``LabelSmoothingLoss`` is created only when ``config`` has a truthy
        ``label_smoothing`` attribute.
        """
        super(UnilmForSeq2Seq, self).__init__(config)
        self.bert = UnilmModel(config)
        self.cls = BertOnlyMLMHead(config)
        self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none')
        smoothing = getattr(config, 'label_smoothing', None)
        if smoothing:
            self.crit_mask_lm_smoothed = LabelSmoothingLoss(
                smoothing, config.vocab_size, ignore_index=0, reduction='none')
        else:
            self.crit_mask_lm_smoothed = None
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """Tie (or clone, as ``_tie_or_clone_weights`` decides) the LM decoder
        weights to the input word embeddings."""
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, masked_pos=None, masked_weights=None, num_tokens_a=None, num_tokens_b=None):
        """Return prediction scores (no labels) or the masked-LM loss (with labels).

        Without ``masked_lm_labels``: returns LM scores for every position, or
        only for ``masked_pos`` when that is given.

        With ``masked_lm_labels``: returns the per-token loss weighted by
        ``masked_weights`` and normalised by their sum.

        ``num_tokens_a`` and ``num_tokens_b`` are accepted but not used.
        """
        sequence_output, _ = self.bert(
            input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)

        def select_positions(states, positions):
            # Gather the hidden vectors at the given positions along dim 1.
            index = positions.unsqueeze(2).expand(-1, -1, states.size(-1))
            return torch.gather(states, 1, index)

        if masked_lm_labels is None:
            if masked_pos is not None:
                sequence_output = select_positions(sequence_output, masked_pos)
            return self.cls(sequence_output)

        masked_scores = self.cls(select_positions(sequence_output, masked_pos))
        if self.crit_mask_lm_smoothed is not None:
            per_token_loss = self.crit_mask_lm_smoothed(
                F.log_softmax(masked_scores.float(), dim=-1), masked_lm_labels)
        else:
            per_token_loss = self.crit_mask_lm(
                masked_scores.transpose(1, 2).float(), masked_lm_labels)

        # Weighted sum of the per-token losses divided by the total weight;
        # the epsilon guards against an all-zero weight tensor.
        per_token_loss = per_token_loss.float()
        weights = masked_weights.type_as(per_token_loss)
        weighted = per_token_loss * weights
        return (weighted / (torch.sum(weights) + 1e-5)).sum()
class UnilmForSeq2SeqDecode(UnilmPreTrainedModel):
    """Incremental decoder: greedy decoding or beam search over a UniLM model.

    At every step the newest token(s) plus one mask-token placeholder are fed
    to the incremental backbone together with the cached embeddings and layer
    outputs of all earlier positions; the prediction at the placeholder
    position gives the next token.
    """

    def __init__(self, config, mask_word_id=0,
                 search_beam_size=1, length_penalty=1.0, eos_id=0, sos_id=0,
                 forbid_duplicate_ngrams=False, forbid_ignore_set=None, ngram_size=3, min_len=0):
        """Store the decoding options and build the backbone and LM head.

        Args:
            config: model configuration passed to the backbone and head.
            mask_word_id: token id used for the placeholder appended at each step.
            search_beam_size: beam width; values above 1 make ``forward``
                delegate to ``beam_search``.
            length_penalty: exponent of the length normalisation applied when
                ranking finished hypotheses (applied only when > 0).
            eos_id: token id that marks a finished hypothesis.
            sos_id: stored on the instance; not read by this class.
            forbid_duplicate_ngrams: when true, beam search blocks tokens that
                would repeat an already generated n-gram.
            forbid_ignore_set: token ids exempt from the n-gram blocking.
            ngram_size: n-gram length used for the blocking.
            min_len: while the number of generated tokens is at most this
                value, ``eos_id`` is suppressed in beam search.
        """
        super(UnilmForSeq2SeqDecode, self).__init__(config)
        self.bert = UnilmModelIncr(config)
        self.cls = BertOnlyMLMHead(config)
        self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none')
        self.mask_word_id = mask_word_id
        self.search_beam_size = search_beam_size
        self.length_penalty = length_penalty
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.forbid_duplicate_ngrams = forbid_duplicate_ngrams
        self.forbid_ignore_set = forbid_ignore_set
        self.ngram_size = ngram_size
        self.min_len = min_len
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """Tie (or clone, as ``_tie_or_clone_weights`` decides) the LM decoder
        weights to the input word embeddings."""
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, token_type_ids, position_ids, attention_mask):
        """Decode greedily, or delegate to ``beam_search`` when the beam is > 1.

        ``input_ids`` holds the source tokens; ``token_type_ids``,
        ``position_ids`` and ``attention_mask`` cover the full (source +
        target) length, and ``token_type_ids.size(1)`` fixes the total length
        to decode up to.

        Returns a ``(batch, generated_len)`` tensor of token ids in the greedy
        case, where ``generated_len`` is
        ``token_type_ids.size(1) - input_ids.size(1)``.
        """
        if self.search_beam_size > 1:
            return self.beam_search(input_ids, token_type_ids, position_ids, attention_mask)

        batch_size, input_length = input_ids.size(0), input_ids.size(1)
        output_length = token_type_ids.size(1)

        output_ids = []
        prev_embedding = None
        prev_encoded_layers = None
        curr_ids = input_ids
        mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id)
        next_pos = input_length

        while next_pos < output_length:
            curr_length = curr_ids.size(1)
            start_pos = next_pos - curr_length
            # Newest token(s) followed by the placeholder to be predicted.
            x_input_ids = torch.cat((curr_ids, mask_ids), dim=1)

            curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1]
            curr_attention_mask = attention_mask[:,
                                                 start_pos:next_pos + 1, :next_pos + 1]
            curr_position_ids = position_ids[:, start_pos:next_pos + 1]
            new_embedding, new_encoded_layers, _ = \
                self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask,
                          output_all_encoded_layers=True, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers)

            last_hidden = new_encoded_layers[-1][:, -1:, :]
            prediction_scores = self.cls(last_hidden)
            _, max_ids = torch.max(prediction_scores, dim=-1)
            output_ids.append(max_ids)

            # Cache everything except the placeholder position.
            if prev_embedding is None:
                prev_embedding = new_embedding[:, :-1, :]
            else:
                prev_embedding = torch.cat(
                    (prev_embedding, new_embedding[:, :-1, :]), dim=1)
            if prev_encoded_layers is None:
                prev_encoded_layers = [x[:, :-1, :]
                                       for x in new_encoded_layers]
            else:
                prev_encoded_layers = [torch.cat((old, new[:, :-1, :]), dim=1)
                                       for old, new in zip(prev_encoded_layers, new_encoded_layers)]
            curr_ids = max_ids
            next_pos += 1

        if not output_ids:
            # Nothing to decode: torch.cat([]) would raise, so return an
            # empty (batch, 0) tensor instead.
            return input_ids.new_zeros((batch_size, 0))
        return torch.cat(output_ids, dim=1)

    def beam_search(self, input_ids, token_type_ids, position_ids, attention_mask):
        """Decode with beam search of width ``self.search_beam_size``.

        Returns a dict with keys ``'pred_seq'``, ``'scores'``, ``'wids'`` and
        ``'ptrs'``. Each value is a tensor with the batch as first dimension,
        zero-padded along dimension 1 to ``token_type_ids.size(1)`` and placed
        on ``input_ids.device``:

        * ``pred_seq``: best hypothesis per batch item (token ids),
        * ``scores``: accumulated log-scores of the kept beams per step,
        * ``wids``: token ids of the kept beams per step,
        * ``ptrs``: for each kept beam, the index of its parent beam in the
          previous step.
        """
        batch_size, input_length = input_ids.size(0), input_ids.size(1)
        output_length = token_type_ids.size(1)

        prev_embedding = None
        prev_encoded_layers = None
        curr_ids = input_ids
        mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id)
        next_pos = input_length

        K = self.search_beam_size

        total_scores = []
        beam_masks = []
        step_ids = []
        step_back_ptrs = []
        partial_seqs = []
        forbid_word_mask = None
        buf_matrix = None
        vocab_size = None

        def first_expand(x):
            # (batch, ...) -> (batch * K, ...): each item repeated K times in a row.
            input_shape = list(x.size())
            expanded_shape = input_shape[:1] + [1] + input_shape[1:]
            x = torch.reshape(x, expanded_shape)
            repeat_count = [1, K] + [1] * (len(input_shape) - 1)
            x = x.repeat(*repeat_count)
            return torch.reshape(x, [input_shape[0] * K] + input_shape[1:])

        def select_beam_items(x, ids):
            # Reorder the (batch * K, ...) tensor x so that beam k of each
            # batch item holds the entry of its parent beam ids[b, k].
            id_shape = list(ids.size())
            id_rank = len(id_shape)
            assert len(id_shape) == 2
            x_shape = list(x.size())
            x = torch.reshape(x, [batch_size, K] + x_shape[1:])
            x_rank = len(x_shape) + 1
            assert x_rank >= 2
            if id_rank < x_rank:
                ids = torch.reshape(
                    ids, id_shape + [1] * (x_rank - id_rank))
                ids = ids.expand(id_shape + x_shape[1:])
            y = torch.gather(x, 1, ids)
            return torch.reshape(y, x_shape)

        def get_dup_ngram_candidates(seq, n):
            # Tokens that, if appended to seq, would repeat an n-gram that
            # already occurs in seq (tokens in forbid_ignore_set are exempt).
            cands = set()
            if len(seq) < n:
                return []
            tail = seq[-(n - 1):]
            if self.forbid_ignore_set and any(tk in self.forbid_ignore_set for tk in tail):
                return []
            for i in range(len(seq) - (n - 1)):
                mismatch = False
                for j in range(n - 1):
                    if tail[j] != seq[i + j]:
                        mismatch = True
                        break
                if (not mismatch) and not (self.forbid_ignore_set and (seq[i + n - 1] in self.forbid_ignore_set)):
                    cands.add(seq[i + n - 1])
            return list(sorted(cands))

        while next_pos < output_length:
            curr_length = curr_ids.size(1)
            start_pos = next_pos - curr_length
            # Newest token(s) followed by the placeholder to be predicted.
            x_input_ids = torch.cat((curr_ids, mask_ids), dim=1)

            curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1]
            curr_attention_mask = attention_mask[:,
                                                 start_pos:next_pos + 1, :next_pos + 1]
            curr_position_ids = position_ids[:, start_pos:next_pos + 1]
            new_embedding, new_encoded_layers, _ = \
                self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask,
                          output_all_encoded_layers=True, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers)

            last_hidden = new_encoded_layers[-1][:, -1:, :]
            prediction_scores = self.cls(last_hidden)
            log_scores = torch.nn.functional.log_softmax(
                prediction_scores, dim=-1)
            if forbid_word_mask is not None:
                log_scores += (forbid_word_mask * -10000.0)
            if self.min_len and (next_pos - input_length + 1 <= self.min_len):
                log_scores[:, :, self.eos_id].fill_(-10000.0)
            kk_scores, kk_ids = torch.topk(log_scores, k=K)
            if len(total_scores) == 0:
                # First step: one hypothesis per batch item, so the K best
                # tokens are the K beams and every back-pointer is 0.
                k_ids = torch.reshape(kk_ids, [batch_size, K])
                # Same dtype/device as k_ids (was a CPU-only tensor).
                back_ptrs = torch.zeros_like(k_ids)
                k_scores = torch.reshape(kk_scores, [batch_size, K])
            else:
                last_eos = torch.reshape(
                    beam_masks[-1], [batch_size * K, 1, 1])
                last_seq_scores = torch.reshape(
                    total_scores[-1], [batch_size * K, 1, 1])
                # Beams that already emitted EOS are pushed out of the top-K.
                kk_scores += last_eos * (-10000.0) + last_seq_scores
                kk_scores = torch.reshape(kk_scores, [batch_size, K * K])
                k_scores, k_ids = torch.topk(kk_scores, k=K)
                # Integer floor division: the parent beam of flat index j is
                # j // K. torch.div() on integer tensors performs true
                # division in current PyTorch and would yield float indices
                # that torch.gather rejects.
                back_ptrs = k_ids // K
                kk_ids = torch.reshape(kk_ids, [batch_size, K * K])
                k_ids = torch.gather(kk_ids, 1, k_ids)
            step_back_ptrs.append(back_ptrs)
            step_ids.append(k_ids)
            beam_masks.append(torch.eq(k_ids, self.eos_id).float())
            total_scores.append(k_scores)

            is_first = (prev_embedding is None)

            # Extend the caches (dropping the placeholder position) and
            # reorder them to follow the surviving beams.
            if prev_embedding is None:
                prev_embedding = first_expand(new_embedding[:, :-1, :])
            else:
                prev_embedding = torch.cat(
                    (prev_embedding, new_embedding[:, :-1, :]), dim=1)
                prev_embedding = select_beam_items(
                    prev_embedding, back_ptrs)
            if prev_encoded_layers is None:
                prev_encoded_layers = [first_expand(
                    x[:, :-1, :]) for x in new_encoded_layers]
            else:
                prev_encoded_layers = [torch.cat((old, new[:, :-1, :]), dim=1)
                                       for old, new in zip(prev_encoded_layers, new_encoded_layers)]
                prev_encoded_layers = [select_beam_items(
                    x, back_ptrs) for x in prev_encoded_layers]

            curr_ids = torch.reshape(k_ids, [batch_size * K, 1])

            if is_first:
                # From now on every per-item input is replicated per beam.
                token_type_ids = first_expand(token_type_ids)
                position_ids = first_expand(position_ids)
                attention_mask = first_expand(attention_mask)
                mask_ids = first_expand(mask_ids)

            if self.forbid_duplicate_ngrams:
                wids = step_ids[-1].tolist()
                ptrs = step_back_ptrs[-1].tolist()
                if is_first:
                    partial_seqs = [[wids[b][k]]
                                    for b in range(batch_size) for k in range(K)]
                else:
                    partial_seqs = [partial_seqs[ptrs[b][k] + b * K] + [wids[b][k]]
                                    for b in range(batch_size) for k in range(K)]

                if len(partial_seqs[0]) >= self.ngram_size:
                    dup_cands = [get_dup_ngram_candidates(seq, self.ngram_size)
                                 for seq in partial_seqs]
                    if max(len(x) for x in dup_cands) > 0:
                        if buf_matrix is None:
                            vocab_size = log_scores.size(-1)
                            buf_matrix = np.zeros(
                                (batch_size * K, vocab_size), dtype=float)
                        else:
                            buf_matrix.fill(0)
                        for bk, cands in enumerate(dup_cands):
                            for wid in cands:
                                buf_matrix[bk, wid] = 1.0
                        # Created on the scores' device rather than via a
                        # hard-coded .cuda(), so CPU decoding works too.
                        forbid_word_mask = torch.tensor(
                            buf_matrix, dtype=log_scores.dtype, device=log_scores.device)
                        forbid_word_mask = torch.reshape(
                            forbid_word_mask, [batch_size * K, 1, vocab_size])
                    else:
                        forbid_word_mask = None
            next_pos += 1

        # Per step: nested lists indexed [batch][beam].
        total_scores = [x.tolist() for x in total_scores]
        step_ids = [x.tolist() for x in step_ids]
        step_back_ptrs = [x.tolist() for x in step_back_ptrs]

        # Back-tracking: pick the best finished hypothesis per batch item.
        traces = {'pred_seq': [], 'scores': [], 'wids': [], 'ptrs': []}
        for b in range(batch_size):
            scores = [x[b] for x in total_scores]
            wids_list = [x[b] for x in step_ids]
            ptrs = [x[b] for x in step_back_ptrs]
            traces['scores'].append(scores)
            traces['wids'].append(wids_list)
            traces['ptrs'].append(ptrs)

            # The first frame in which every beam emitted EOS is the last
            # valid frame; later frames are ignored.
            last_frame_id = len(scores) - 1
            for i, wids in enumerate(wids_list):
                if all(wid == self.eos_id for wid in wids):
                    last_frame_id = i
                    break

            max_score = -math.inf
            frame_id = -1
            pos_in_frame = -1
            for fid in range(last_frame_id + 1):
                for i, wid in enumerate(wids_list[fid]):
                    # Candidates end in EOS, or sit in the last valid frame.
                    if wid == self.eos_id or fid == last_frame_id:
                        s = scores[fid][i]
                        if self.length_penalty > 0:
                            s /= math.pow((5 + fid + 1) / 6.0,
                                          self.length_penalty)
                        if s > max_score:
                            max_score = s
                            frame_id = fid
                            pos_in_frame = i

            if frame_id == -1:
                traces['pred_seq'].append([0])
            else:
                # Follow the back-pointers from the chosen frame to frame 0.
                seq = [wids_list[frame_id][pos_in_frame]]
                for fid in range(frame_id, 0, -1):
                    pos_in_frame = ptrs[fid][pos_in_frame]
                    seq.append(wids_list[fid - 1][pos_in_frame])
                seq.reverse()
                traces['pred_seq'].append(seq)

        def _pad_sequence(sequences, max_len, padding_value=0):
            # Stack variable-length tensors into (len(sequences), max_len, ...),
            # filling the unused tail with padding_value.
            trailing_dims = sequences[0].size()[1:]
            out_dims = (len(sequences), max_len) + trailing_dims
            out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
            for i, tensor in enumerate(sequences):
                length = tensor.size(0)
                out_tensor[i, :length, ...] = tensor
            return out_tensor

        # Convert the nested lists to padded tensors on the input device.
        for k in ('pred_seq', 'scores', 'wids', 'ptrs'):
            ts_list = traces[k]
            if not isinstance(ts_list[0], torch.Tensor):
                dt = torch.float if k == 'scores' else torch.long
                ts_list = [torch.tensor(it, dtype=dt) for it in ts_list]
            traces[k] = _pad_sequence(
                ts_list, output_length, padding_value=0).to(input_ids.device)
        return traces
class UnilmForSeq2SeqDecodeSample(UnilmPreTrainedModel):
    """UniLM backbone plus LM head that scores only the last input position."""

    def __init__(self, config):
        """Build the backbone and LM head, initialise and tie the weights."""
        super(UnilmForSeq2SeqDecodeSample, self).__init__(config)
        self.bert = UnilmModel(config)
        self.cls = BertOnlyMLMHead(config)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """Tie (or clone, as ``_tie_or_clone_weights`` decides) the LM decoder
        weights to the input word embeddings."""
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return the LM head's scores for the final position of the sequence.

        The slice ``[:, -1:, :]`` keeps the sequence dimension, so the head is
        applied to a length-1 sequence per batch item.
        """
        final_layer, _ = self.bert(
            input_ids, token_type_ids, attention_mask,
            output_all_encoded_layers=False)
        return self.cls(final_layer[:, -1:, :])