From 27fd97ce1c507d723895f8b4d6ad8a59ce04676f Mon Sep 17 00:00:00 2001
From: guosheng
Date: Tue, 3 Apr 2018 20:32:34 +0800
Subject: [PATCH] Refine the inference to output special tokens optionally in
 Transformer

---
 .../transformer/config.py |  7 ++
 .../transformer/infer.py  | 84 +++++++++++++++----
 2 files changed, 74 insertions(+), 17 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 8bfdf6461b..0ccebacfd6 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -31,6 +31,11 @@ class InferTaskConfig(object):
     # the number of decoded sentences to output.
     n_best = 1
 
+    # the flags indicating whether to output the special tokens.
+    output_bos = False
+    output_eos = False
+    output_unk = False
+
     # the directory for loading the trained model.
     model_path = "trained_models/pass_1.infer.model"
 
@@ -56,6 +61,8 @@ class ModelHyperParams(object):
     bos_idx = 0
     # index for <eos> token
     eos_idx = 1
+    # index for <unk> token
+    unk_idx = 2
 
     # position value corresponding to the <pad> token.
     pos_pad_idx = 0
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index 02674df125..14d476105d 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -11,10 +11,25 @@
 from train import pad_batch_data
 
 
-def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
-                    decoder, dec_in_names, dec_out_names, beam_size, max_length,
-                    n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
-                    bos_idx, eos_idx):
+def translate_batch(exe,
+                    src_words,
+                    encoder,
+                    enc_in_names,
+                    enc_out_names,
+                    decoder,
+                    dec_in_names,
+                    dec_out_names,
+                    beam_size,
+                    max_length,
+                    n_best,
+                    batch_size,
+                    n_head,
+                    src_pad_idx,
+                    trg_pad_idx,
+                    bos_idx,
+                    eos_idx,
+                    unk_idx,
+                    output_unk=True):
     """
     Run the encoder program once and run the decoder program multiple times to
     implement beam search externally.
@@ -48,7 +63,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
     # size of feeded batch is changing.
     beam_map = range(batch_size)
 
-    def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
+    def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
         """
         Decode and select n_best sequences for one instance by backtrace.
         """
@@ -60,7 +75,8 @@ def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
                 seq.append(next_ids[j][k])
                 k = prev_branchs[j][k]
             seq = seq[::-1]
-            seq = [bos_idx] + seq if add_bos else seq
+            # Add the <bos>, since next_ids don't include the <bos>.
+            seq = [bos_idx] + seq
             seqs.append(seq)
         return seqs
 
@@ -114,8 +130,7 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
         trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
         trg_words = np.array(
             [
-                beam_backtrace(
-                    prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
+                beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
                 for beam_idx in active_beams
             ],
             dtype="int64")
@@ -167,6 +182,8 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
         predict_all = (predict_all + scores[beam_map].reshape(
             [len(beam_map) * beam_size, -1])).reshape(
                 [len(beam_map), beam_size, -1])
+        if not output_unk:  # To exclude the <unk> token.
+            predict_all[:, :, unk_idx] = -1e9
         active_beams = []
         for inst_idx, beam_idx in enumerate(beam_map):
             predict = (predict_all[inst_idx, :, :]
@@ -187,7 +204,10 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
         dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
 
     # Decode beams and select n_best sequences for each instance by backtrace.
-    seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)]
+    seqs = [
+        beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
+        for beam_idx in range(batch_size)
+    ]
 
     return seqs, scores[:, :n_best].tolist()
 
@@ -254,17 +274,47 @@ def main():
     trg_idx2word = paddle.dataset.wmt16.get_dict(
         "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
 
+    def post_process_seq(seq,
+                         bos_idx=ModelHyperParams.bos_idx,
+                         eos_idx=ModelHyperParams.eos_idx,
+                         output_bos=InferTaskConfig.output_bos,
+                         output_eos=InferTaskConfig.output_eos):
+        """
+        Post-process the beam-search decoded sequence. Truncate from the first
+        <eos> and remove the <bos> and <eos> tokens currently.
+        """
+        eos_pos = len(seq) - 1
+        for i, idx in enumerate(seq):
+            if idx == eos_idx:
+                eos_pos = i
+                break
+        seq = seq[:eos_pos + 1]
+        return filter(
+            lambda idx: (output_bos or idx != bos_idx) and \
+                (output_eos or idx != eos_idx),
+            seq)
+
     for batch_id, data in enumerate(test_data()):
         batch_seqs, batch_scores = translate_batch(
-            exe, [item[0] for item in data], encoder_program,
-            encoder_input_data_names, [enc_output.name], decoder_program,
-            decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
-            InferTaskConfig.max_length, InferTaskConfig.n_best,
-            len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
-            ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
-            ModelHyperParams.eos_idx)
+            exe, [item[0] for item in data],
+            encoder_program,
+            encoder_input_data_names, [enc_output.name],
+            decoder_program,
+            decoder_input_data_names, [predict.name],
+            InferTaskConfig.beam_size,
+            InferTaskConfig.max_length,
+            InferTaskConfig.n_best,
+            len(data),
+            ModelHyperParams.n_head,
+            ModelHyperParams.src_pad_idx,
+            ModelHyperParams.trg_pad_idx,
+            ModelHyperParams.bos_idx,
+            ModelHyperParams.eos_idx,
+            ModelHyperParams.unk_idx,
+            output_unk=InferTaskConfig.output_unk)
         for i in range(len(batch_seqs)):
-            seqs = batch_seqs[i]
+            # Post-process the beam-search decoded sequences.
+            seqs = map(post_process_seq, batch_seqs[i])
            scores = batch_scores[i]
             for seq in seqs:
                 print(" ".join([trg_idx2word[idx] for idx in seq]))
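
Note (standalone sketch, not part of the patch): the change combines two small
techniques, which the snippet below exercises in isolation. It reuses the index
values from config.py (bos_idx=0, eos_idx=1, unk_idx=2), but the toy scores,
the vocabulary size of 4, and the Python 3-style list comprehension (the patch
itself targets Python 2 and uses filter/map) are assumptions for illustration
only.

    import numpy as np

    bos_idx, eos_idx, unk_idx = 0, 1, 2

    # 1) Excluding <unk> from beam search: forcing its log-probability to
    #    -1e9 before the top-k selection means it can never win a beam slot.
    predict_all = np.log(np.array([[0.1, 0.2, 0.6, 0.1]]))  # 1 beam, vocab 4
    predict_all[:, unk_idx] = -1e9
    top2 = np.argsort(predict_all, axis=-1)[:, -2:]  # ids of the 2 best tokens
    assert unk_idx not in top2

    # 2) Post-processing a decoded sequence: truncate at the first <eos>,
    #    then drop <bos>/<eos> unless the corresponding output flag keeps them.
    def post_process_seq(seq, output_bos=False, output_eos=False):
        eos_pos = len(seq) - 1
        for i, idx in enumerate(seq):
            if idx == eos_idx:
                eos_pos = i
                break
        seq = seq[:eos_pos + 1]
        return [idx for idx in seq
                if (output_bos or idx != bos_idx) and
                   (output_eos or idx != eos_idx)]

    print(post_process_seq([bos_idx, 5, 7, eos_idx, 9]))  # -> [5, 7]

The -1e9 trick works because beam search only compares scores: pushing one
log-probability to a huge negative value removes that token from every top-k
selection without renormalizing the model's output distribution.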