Decouple the program desc from batch_size in Transformer. #783

Merged
fluid/neural_machine_translation/transformer/config.py (5 changes: 3 additions, 2 deletions)
@@ -25,8 +25,7 @@ class TrainTaskConfig(object):
 class InferTaskConfig(object):
     use_gpu = False
     # the number of examples in one run for sequence generation.
-    # currently the batch size can only be set to 1.
-    batch_size = 1
+    batch_size = 10

     # the parameters for beam search.
     beam_size = 5
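With batch_size no longer pinned to 1, one decoder run feeds batch_size * beam_size candidate rows, and that count shrinks as beams finish. A quick sketch of the arithmetic, using the values from this config:

batch_size = 10  # InferTaskConfig.batch_size after this change
beam_size = 5    # InferTaskConfig.beam_size
# One decoder step initially scores batch_size * beam_size = 50 rows; once,
# say, 4 source sentences finish, only (10 - 4) * 5 = 30 rows remain, which
# is why the program desc can no longer hard-code the batch size.
print(batch_size * beam_size)        # 50
print((batch_size - 4) * beam_size)  # 30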
@@ -103,6 +102,7 @@ class ModelHyperParams(object):
     "src_word",
     "src_pos",
     "src_slf_attn_bias",
+    "src_data_shape",
     "src_slf_attn_pre_softmax_shape",
     "src_slf_attn_post_softmax_shape", )

@@ -112,6 +112,7 @@ class ModelHyperParams(object):
     "trg_pos",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
+    "trg_data_shape",
     "trg_slf_attn_pre_softmax_shape",
     "trg_slf_attn_post_softmax_shape",
     "trg_src_attn_pre_softmax_shape",
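The new *_data_shape names are fed like any other input: each name in these tuples is zipped with one numpy array at run time (see dict(zip(dec_in_names, dec_in_data)) in infer.py below). A minimal sketch of the pairing; the sizes here are illustrative assumptions, not the model's exact shapes:

import numpy as np

encoder_input_data_names = (
    "src_word",
    "src_pos",
    "src_slf_attn_bias",
    "src_data_shape",
    "src_slf_attn_pre_softmax_shape",
    "src_slf_attn_post_softmax_shape", )

# Hypothetical small sizes, only to show the name/array pairing.
batch, max_len, n_head, d_model = 2, 4, 8, 512
enc_in_data = [
    np.zeros((batch, max_len, 1), dtype="int64"),                # src_word
    np.zeros((batch, max_len, 1), dtype="int64"),                # src_pos
    np.zeros((batch, n_head, max_len, max_len), dtype="float32"),# attn bias
    np.array([-1, max_len, d_model], dtype="int32"),             # src_data_shape
    np.array([-1, max_len], dtype="int32"),                      # pre-softmax
    np.array([batch, n_head, max_len, max_len], dtype="int32"),  # post-softmax
]
feed = dict(zip(encoder_input_data_names, enc_in_data))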
fluid/neural_machine_translation/transformer/infer.py (90 changes: 61 additions, 29 deletions)
@@ -24,6 +24,7 @@ def translate_batch(exe,
                     n_best,
                     batch_size,
                     n_head,
+                    d_model,
                     src_pad_idx,
                     trg_pad_idx,
                     bos_idx,
@@ -43,6 +44,11 @@ def translate_batch(exe,
         return_pos=True,
         return_attn_bias=True,
         return_max_len=False)
+    # Append the data shape input to reshape the output of embedding layer.
+    enc_in_data = enc_in_data + [
+        np.array(
+            [-1, enc_in_data[2].shape[-1], d_model], dtype="int32")
+    ]
     # Append the shape inputs to reshape before and after softmax in encoder
     # self attention.
     enc_in_data = enc_in_data + [
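The fed shape starts with -1 so that, as in numpy.reshape, the batch dimension is inferred at run time instead of being fixed in the program desc; enc_in_data[2].shape[-1] is the padded source length. A standalone numpy analogue, not part of the diff:

import numpy as np

d_model, max_len = 4, 3
flat = np.zeros((2 * max_len, d_model))  # embedding output for batch = 2
assert flat.reshape([-1, max_len, d_model]).shape == (2, max_len, d_model)

flat = np.zeros((5 * max_len, d_model))  # same shape spec, batch = 5
assert flat.reshape([-1, max_len, d_model]).shape == (5, max_len, d_model)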
@@ -59,9 +65,14 @@ def translate_batch(exe,
     scores = np.zeros((batch_size, beam_size), dtype="float32")
     prev_branchs = [[] for i in range(batch_size)]
     next_ids = [[] for i in range(batch_size)]
-    # Use beam_map to map the instance idx in batch to beam idx, since the
-    # size of feeded batch is changing.
-    beam_map = range(batch_size)
+    # Use beam_inst_map to map beam idx to the instance idx in batch, since
+    # the size of the fed batch is changing.
+    beam_inst_map = {
+        beam_idx: inst_idx
+        for inst_idx, beam_idx in enumerate(range(batch_size))
+    }
+    # Use active_beams to record the indices of alive beams.
+    active_beams = range(batch_size)

     def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
         """
@@ -98,8 +109,14 @@ def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
             [-1e9]).astype("float32")
         # This is used to remove attention on the paddings of source sequences.
         trg_src_attn_bias = np.tile(
-            src_slf_attn_bias[:, :, ::src_max_length, :],
-            [beam_size, 1, trg_max_len, 1])
+            src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
+            [1, beam_size, 1, trg_max_len, 1]).reshape([
+                -1, src_slf_attn_bias.shape[1], trg_max_len,
+                src_slf_attn_bias.shape[-1]
+            ])
+        # Append the shape input to reshape the output of embedding layer.
+        trg_data_shape = np.array(
+            [batch_size * beam_size, trg_max_len, d_model], dtype="int32")
         # Append the shape inputs to reshape before and after softmax in
         # decoder self attention.
         trg_slf_attn_pre_softmax_shape = np.array(
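The [:, np.newaxis] tile-and-reshape is the substantive fix here: it lays out the beam_size copies of each instance contiguously, whereas tiling along axis 0 (the old code) interleaves instances. A standalone numpy sketch of the two layouts, not part of the diff:

import numpy as np

x = np.array([[1], [2]])  # two instances, one feature each
beam_size = 2
old = np.tile(x, [beam_size, 1])
# [[1], [2], [1], [2]] - instances interleaved across the beam dimension
new = np.tile(x[:, np.newaxis], [1, beam_size, 1]).reshape([-1, 1])
# [[1], [1], [2], [2]] - each instance's beams kept adjacent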
@@ -112,22 +129,24 @@ def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
             [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
         trg_src_attn_post_softmax_shape = np.array(
             trg_src_attn_bias.shape, dtype="int32")
-        enc_output = np.tile(enc_output, [beam_size, 1, 1])
+        enc_output = np.tile(
+            enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
+                [-1, enc_output.shape[-2], enc_output.shape[-1]])
         return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output

-    def update_dec_in_data(dec_in_data, next_ids, active_beams):
+    def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
         """
         Update the input data of decoder mainly by slicing from the previous
         input data and dropping the finished instance beams.
         """
         trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output = dec_in_data
-        trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output = dec_in_data
+        trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
         trg_words = np.array(
             [
                 beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
@@ -138,6 +157,7 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
         trg_pos = np.array(
             [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
             dtype="int64").reshape([-1, 1])
+        active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
         active_beams_indice = (
             (np.array(active_beams) * beam_size)[:, np.newaxis] +
             np.array(range(beam_size))[np.newaxis, :]).flatten()
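The active_beams_indice computation expands surviving instance rows into flat beam rows by broadcasting. A standalone numpy check, assuming beam_size = 3 and instances 0 and 2 still alive:

import numpy as np

beam_size = 3
active_beams = np.array([0, 2])  # instance indices after the remapping above
active_beams_indice = (
    (active_beams * beam_size)[:, np.newaxis] +
    np.arange(beam_size)[np.newaxis, :]).flatten()
print(active_beams_indice)  # [0 1 2 6 7 8]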
@@ -152,6 +172,10 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
         trg_src_attn_bias = np.tile(trg_src_attn_bias[
             active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
                                     [1, 1, trg_cur_len, 1])
+        # Append the shape input to reshape the output of embedding layer.
+        trg_data_shape = np.array(
+            [len(active_beams) * beam_size, trg_cur_len, d_model],
+            dtype="int32")
         # Append the shape inputs to reshape before and after softmax in
         # decoder self attention.
         trg_slf_attn_pre_softmax_shape = np.array(
@@ -166,9 +190,9 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
             trg_src_attn_bias.shape, dtype="int32")
         enc_output = enc_output[active_beams_indice, :, :]
         return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output

     dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
                                    enc_output)
@@ -177,15 +201,18 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
             feed=dict(zip(dec_in_names, dec_in_data)),
             fetch_list=dec_out_names)[0]
         predict_all = np.log(
-            predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
-                                                                        -1, :])
-        predict_all = (predict_all + scores[beam_map].reshape(
-            [len(beam_map) * beam_size, -1])).reshape(
-                [len(beam_map), beam_size, -1])
+            predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
+            [:, -1, :])
+        predict_all = (predict_all + scores[active_beams].reshape(
+            [len(beam_inst_map) * beam_size, -1])).reshape(
+                [len(beam_inst_map), beam_size, -1])
         if not output_unk:  # To exclude the <unk> token.
             predict_all[:, :, unk_idx] = -1e9
         active_beams = []
-        for inst_idx, beam_idx in enumerate(beam_map):
+        for beam_idx in range(batch_size):
+            if not beam_inst_map.has_key(beam_idx):
+                continue
+            inst_idx = beam_inst_map[beam_idx]
             predict = (predict_all[inst_idx, :, :]
                        if i != 0 else predict_all[inst_idx, 0, :]).flatten()
             top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
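np.argpartition is what keeps the per-beam top-k cheap here: it returns the indices of the beam_size largest scores in linear time, without fully sorting the vocabulary. A standalone example, not part of the diff:

import numpy as np

beam_size = 3
predict = np.array([0.1, 0.7, 0.3, 0.9, 0.2])
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
print(np.sort(top_k_indice))  # [1 2 3] - indices of the three largest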
@@ -198,10 +225,14 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
             next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
             if next_ids[beam_idx][-1][0] != eos_idx:
                 active_beams.append(beam_idx)
-        beam_map = active_beams
-        if len(beam_map) == 0:
+        if len(active_beams) == 0:
             break
-        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
+        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
+                                         beam_inst_map)
+        beam_inst_map = {
+            beam_idx: inst_idx
+            for inst_idx, beam_idx in enumerate(active_beams)
+        }

     # Decode beams and select n_best sequences for each instance by backtrace.
     seqs = [
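Note that beam_inst_map.has_key(beam_idx) above is Python 2 idiom; in Python 3 it would be beam_idx in beam_inst_map. A toy rerun of the end-of-step bookkeeping, with made-up token ids:

eos_idx = 1
# Newest token of each beam after one step (made-up values).
next_ids = {0: [[5]], 1: [[1]], 2: [[7]]}
active_beams = [b for b in range(3) if next_ids[b][-1][0] != eos_idx]
beam_inst_map = {beam_idx: inst_idx
                 for inst_idx, beam_idx in enumerate(active_beams)}
# active_beams == [0, 2]: beam 1 emitted <eos> and is dropped, and beam 2's
# data now lives at row 1 of the next, smaller fed batch.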
@@ -215,10 +246,8 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
 def main():
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # The current program desc is coupled with batch_size and the only
-    # supported batch size is 1 currently.
+
     encoder_program = fluid.Program()
-    model.batch_size = InferTaskConfig.batch_size
     with fluid.program_guard(main_program=encoder_program):
         enc_output = encoder(
             ModelHyperParams.src_vocab_size + 1,
@@ -228,7 +257,6 @@ def main():
             ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
             ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)

-    model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
         predict = decoder(
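For reference, the two-program structure that these deleted lines used to parameterize; a minimal sketch assuming the paddle.fluid import this file already uses, with the op-building bodies elided:

import paddle.fluid as fluid

# Each program gets its own desc; the batch size is now carried by the fed
# *_data_shape arrays instead of being compiled into either desc.
encoder_program = fluid.Program()
with fluid.program_guard(main_program=encoder_program):
    pass  # build encoder ops here, as in the diff above

decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
    pass  # build decoder ops here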
@@ -273,6 +301,9 @@ def main():

     trg_idx2word = paddle.dataset.wmt16.get_dict(
         "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
+    # Append the <pad> token since the dict provided by dataset.wmt16 does
+    # not include it.
+    trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"
Review comment on the added "<pad>" line:

Contributor: Please fix this in next PR.

Collaborator (Author): Get it.

 def post_process_seq(seq,
                      bos_idx=ModelHyperParams.bos_idx,
@@ -306,6 +337,7 @@ def post_process_seq(seq,
             InferTaskConfig.n_best,
             len(data),
             ModelHyperParams.n_head,
+            ModelHyperParams.d_model,
             ModelHyperParams.src_pad_idx,
             ModelHyperParams.trg_pad_idx,
             ModelHyperParams.bos_idx,