Merge pull request #809 from guoshengCS/refine-transformer-token
Refine the inference to output special tokens optionally in Transformer
guoshengCS authored Apr 4, 2018
2 parents 288664c + 27fd97c commit c2de925
Showing 2 changed files with 74 additions and 17 deletions.
7 changes: 7 additions & 0 deletions fluid/neural_machine_translation/transformer/config.py
@@ -31,6 +31,11 @@ class InferTaskConfig(object):
# the number of decoded sentences to output.
n_best = 1

# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False

# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"

@@ -56,6 +61,8 @@ class ModelHyperParams(object):
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2

# position value corresponding to the <pad> token.
pos_pad_idx = 0
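The new flags are consumed by infer.py below: output_unk masks the <unk> token during beam search, while output_bos and output_eos control how the decoded sequences are post-processed. A minimal sketch of overriding them before running inference (assuming config.py is importable from the working directory, as in this repo layout):

from config import InferTaskConfig

# Keep <eos> in the printed translations; <bos> and <unk> remain suppressed.
InferTaskConfig.output_eos = True
InferTaskConfig.output_bos = False
InferTaskConfig.output_unk = False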
84 changes: 67 additions & 17 deletions fluid/neural_machine_translation/transformer/infer.py
@@ -11,10 +11,25 @@
from train import pad_batch_data


def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
decoder, dec_in_names, dec_out_names, beam_size, max_length,
n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
bos_idx, eos_idx):
def translate_batch(exe,
src_words,
encoder,
enc_in_names,
enc_out_names,
decoder,
dec_in_names,
dec_out_names,
beam_size,
max_length,
n_best,
batch_size,
n_head,
src_pad_idx,
trg_pad_idx,
bos_idx,
eos_idx,
unk_idx,
output_unk=True):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
@@ -48,7 +63,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
# size of the fed batch is changing.
beam_map = range(batch_size)

def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
"""
Decode and select n_best sequences for one instance by backtrace.
"""
@@ -60,7 +75,8 @@ def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
seq.append(next_ids[j][k])
k = prev_branchs[j][k]
seq = seq[::-1]
seq = [bos_idx] + seq if add_bos else seq
# Add the <bos>, since next_ids doesn't include the <bos>.
seq = [bos_idx] + seq
seqs.append(seq)
return seqs
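
For illustration, a standalone toy version of the same backtrace idea; the prev_branchs/next_ids values below are made up, and bos_idx follows the config above:

# prev_branchs[t][k] is the parent beam of candidate k at step t,
# next_ids[t][k] is the token chosen by that candidate.
prev_branchs = [[0, 0], [1, 0]]  # two decoding steps, beam width 2
next_ids = [[5, 7], [9, 6]]
bos_idx = 0

def backtrace(prev_branchs, next_ids, n_best=2):
    seqs = []
    for i in range(n_best):
        k = i
        seq = []
        for j in range(len(prev_branchs) - 1, -1, -1):
            seq.append(next_ids[j][k])
            k = prev_branchs[j][k]
        seqs.append([bos_idx] + seq[::-1])  # prepend <bos>, as above
    return seqs

print(backtrace(prev_branchs, next_ids))  # [[0, 7, 9], [0, 5, 6]]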

@@ -114,8 +130,7 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
trg_cur_len = len(next_ids[0]) + 1 # include the <bos>
trg_words = np.array(
[
beam_backtrace(
prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
for beam_idx in active_beams
],
dtype="int64")
@@ -167,6 +182,8 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
predict_all = (predict_all + scores[beam_map].reshape(
[len(beam_map) * beam_size, -1])).reshape(
[len(beam_map), beam_size, -1])
if not output_unk: # To exclude the <unk> token.
predict_all[:, :, unk_idx] = -1e9
active_beams = []
for inst_idx, beam_idx in enumerate(beam_map):
predict = (predict_all[inst_idx, :, :]
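
A small numpy sketch of the masking used above: writing a large negative score into the <unk> column removes it from any later top-k selection (the vocabulary size and scores here are illustrative only):

import numpy as np

unk_idx = 2
# scores over a 5-word vocabulary for the beam_size=2 candidates of one instance
predict = np.array([[0.1, 0.3, 0.9, 0.2, 0.4],
                    [0.2, 0.8, 0.7, 0.1, 0.3]])
predict[:, unk_idx] = -1e9     # <unk> can no longer win
print(predict.argmax(axis=1))  # [4 1]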
@@ -187,7 +204,10 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)

# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)]
seqs = [
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
for beam_idx in range(batch_size)
]

return seqs, scores[:, :n_best].tolist()

@@ -254,17 +274,47 @@ def main():
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)

def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence: truncate at the first
<eos> and, unless the output flags are set, remove the <bos> and <eos> tokens.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = seq[:eos_pos + 1]
return filter(
lambda idx: (output_bos or idx != bos_idx) and \
(output_eos or idx != eos_idx),
seq)
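
For example, with the defaults above (bos_idx=0, eos_idx=1, all output flags False), a hypothetical decoded id list [0, 12, 7, 1, 5, 1] is truncated at the first <eos> and stripped to [12, 7]; a standalone sketch of the same rule:

bos_idx, eos_idx = 0, 1
seq = [0, 12, 7, 1, 5, 1]  # <bos> w w <eos> w <eos> (hypothetical ids)
eos_pos = seq.index(eos_idx) if eos_idx in seq else len(seq) - 1
seq = [idx for idx in seq[:eos_pos + 1] if idx not in (bos_idx, eos_idx)]
print(seq)  # [12, 7]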

for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data], encoder_program,
encoder_input_data_names, [enc_output.name], decoder_program,
decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
InferTaskConfig.max_length, InferTaskConfig.n_best,
len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx)
exe, [item[0] for item in data],
encoder_program,
encoder_input_data_names, [enc_output.name],
decoder_program,
decoder_input_data_names, [predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx,
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
output_unk=InferTaskConfig.output_unk)
for i in range(len(batch_seqs)):
seqs = batch_seqs[i]
# Post-process the beam-search decoded sequences.
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
