Avoid predicting <pad> by restricting the size of fc_layer in Transformer #819

Merged
merged 5 commits into from Apr 11, 2018
22 changes: 8 additions & 14 deletions fluid/neural_machine_translation/transformer/config.py
@@ -43,21 +43,16 @@ class InferTaskConfig(object):


class ModelHyperParams(object):
# Dictionary size for source and target language. This model directly uses
# paddle.dataset.wmt16, in which the <bos>, <eos> and <unk> tokens have
# already been added, but the <pad> token has not. Transformer requires
# sequences in a mini-batch to be padded to the same length, so a <pad> token
# is added into the original dictionary in paddle.dataset.wmt16.
# This model directly uses paddle.dataset.wmt16, in which the <bos>, <eos> and
# <unk> tokens have already been added. As for the <pad> token, any token
# included in the dict can be used to pad, since the paddings' loss will be
# masked out and have no effect on parameter gradients.

# size of source word dictionary.
src_vocab_size = 10000
# index for <pad> token in source language.
src_pad_idx = src_vocab_size

# size of target word dictionary
trg_vocab_size = 10000
# index for <pad> token in target language.
trg_pad_idx = trg_vocab_size

# index for <bos> token
bos_idx = 0
@@ -66,11 +61,10 @@ class ModelHyperParams(object):
# index for <unk> token
unk_idx = 2

# position value corresponding to the <pad> token.
pos_pad_idx = 0

# max length of sequences. It should be increased by 1 to include the
# position padding token used by the position encoding.
# max length of sequences.
# The size of the position encoding table should be at least max_length + 1,
# since the sinusoid position encoding starts from 1 and 0 can be used as the
# padding token for position encoding.
max_length = 50

# the dimension for word embeddings, which is also the last dimension of
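A minimal NumPy sketch (illustrative only, not part of this diff; the token ids and per-token losses are made up) of why the new comment holds: once the label weights zero out padded positions, the id chosen for padding cannot affect the summed cost or the gradients.

```python
import numpy as np

eos_idx = 1                                       # reuse <eos> as the padding id
labels = np.array([5, 7, eos_idx, eos_idx])       # last two positions are padding
per_token_loss = np.array([2.3, 1.1, 0.9, 0.4])   # dummy cross-entropy values
weights = np.array([1., 1., 0., 0.])              # 0.0 on padded positions

sum_cost = (per_token_loss * weights).sum()       # padded losses are masked out
token_num = weights.sum()                         # number of real tokens
avg_cost = sum_cost / token_num
print(sum_cost, avg_cost)                         # 3.4 1.7, independent of the pad id
```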
40 changes: 19 additions & 21 deletions fluid/neural_machine_translation/transformer/infer.py
@@ -41,7 +41,7 @@ def translate_batch(exe,
src_pad_idx,
n_head,
is_target=False,
return_pos=True,
is_label=False,
return_attn_bias=True,
return_max_len=False)
# Append the data shape input to reshape the output of the embedding layer.
@@ -250,22 +250,20 @@ def main():
encoder_program = fluid.Program()
with fluid.program_guard(main_program=encoder_program):
enc_output = encoder(
ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout)

decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
predict = decoder(
ModelHyperParams.trg_vocab_size + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout)

# Load model parameters of encoder and decoder separately from the saved
# transformer model.
@@ -301,9 +299,6 @@ def main():

trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
# Append the <pad> token since the dict provided by dataset.wmt16 does
# not include it.
trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"

def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
@@ -327,19 +322,22 @@ def post_process_seq(seq,

for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data],
exe,
[item[0] for item in data],
encoder_program,
encoder_input_data_names, [enc_output.name],
encoder_input_data_names,
[enc_output.name],
decoder_program,
decoder_input_data_names, [predict.name],
decoder_input_data_names,
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.d_model,
ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx,
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
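A rough sketch of the idea in the PR title (plain NumPy rather than the actual fluid graph; shapes and names here are hypothetical): with the extra `+ 1` vocabulary slot gone, the decoder's output projection has exactly `trg_vocab_size` units, so there is no separate <pad> class left for the softmax to assign probability to.

```python
import numpy as np

trg_vocab_size = 10000
d_model = 512

dec_output = np.random.randn(1, d_model)            # one decoder time step
proj_w = np.random.randn(d_model, trg_vocab_size)   # fc_layer weight, no extra <pad> row

logits = dec_output.dot(proj_w)                      # shape (1, trg_vocab_size)
probs = np.exp(logits - logits.max())
probs /= probs.sum()                                 # softmax over real dictionary tokens only
assert probs.shape == (1, trg_vocab_size)            # every class is a real token
```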
25 changes: 5 additions & 20 deletions fluid/neural_machine_translation/transformer/model.py
@@ -199,10 +199,8 @@ def prepare_encoder(src_word,
src_pos,
src_vocab_size,
src_emb_dim,
src_pad_idx,
src_max_len,
dropout_rate=0.,
pos_pad_idx=0,
src_data_shape=None,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
@@ -214,12 +212,10 @@
src_word_emb = layers.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
padding_idx=src_pad_idx,
param_attr=fluid.initializer.Normal(0., 1.))
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
padding_idx=pos_pad_idx,
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
enc_input = src_word_emb + src_pos_enc
@@ -480,12 +476,16 @@ def make_inputs(input_data_names,
append_batch_size=False)
input_layers += [slf_attn_post_softmax_shape]
if src_attn_shape_flag:
# This shape input is used for the reshape before the softmax in the
# encoder-decoder attention.
src_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[2],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_pre_softmax_shape]
# This shape input is used for the reshape after the softmax in the
# encoder-decoder attention.
src_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[4],
@@ -516,10 +516,7 @@ def transformer(
d_value,
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
trg_pad_idx,
pos_pad_idx, ):
dropout_rate, ):
enc_inputs = make_inputs(
encoder_input_data_names,
n_head,
Expand All @@ -543,8 +540,6 @@ def transformer(
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_inputs, )

dec_inputs = make_inputs(
Expand All @@ -570,8 +565,6 @@ def transformer(
d_model,
d_inner_hid,
dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_inputs,
enc_output, )

@@ -606,8 +599,6 @@ def wrap_encoder(src_vocab_size,
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_inputs=None):
"""
The wrapper assembles together all needed layers for the encoder.
@@ -637,10 +628,8 @@
src_pos,
src_vocab_size,
d_model,
src_pad_idx,
max_length,
dropout_rate,
pos_pad_idx,
src_data_shape, )
enc_output = encoder(
enc_input,
@@ -666,8 +655,6 @@ def wrap_decoder(trg_vocab_size,
d_model,
d_inner_hid,
dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_inputs=None,
enc_output=None):
"""
@@ -701,10 +688,8 @@ def wrap_decoder(trg_vocab_size,
trg_pos,
trg_vocab_size,
d_model,
trg_pad_idx,
max_length,
dropout_rate,
pos_pad_idx,
trg_data_shape, )
dec_output = decoder(
dec_input,
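A standalone illustration (made-up NumPy tensors, not fluid code) of why dropping `padding_idx` from the embedding lookups is safe here: the attention bias built from the pad mask already pushes padded keys to a large negative score (-1e9 in this model), so whatever embedding a padded position ends up with, no query attends to it.

```python
import numpy as np

scores = np.array([[1.2, 0.7, 0.3, 0.3]])        # raw attention scores, last two keys are pads
attn_bias = np.array([[0., 0., -1e9, -1e9]])     # bias derived from the padding mask
masked = scores + attn_bias

weights = np.exp(masked - masked.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)        # softmax over the masked scores
print(weights.round(3))                          # padded keys receive ~0 attention weight
```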
51 changes: 31 additions & 20 deletions fluid/neural_machine_translation/transformer/train.py
@@ -15,7 +15,7 @@ def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
return_pos=True,
is_label=False,
return_attn_bias=True,
return_max_len=True):
"""
@@ -24,14 +24,20 @@
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in the dict can be used to pad, since the paddings' loss
# will be masked out by weights and have no effect on parameter gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if return_pos:
inst_pos = np.array([[
pos_i + 1 if w_i != pad_idx else 0 for pos_i, w_i in enumerate(inst)
] for inst in inst_data])

if is_label: # label weight
inst_weight = np.array(
[[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
range(1, len(inst) + 1) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
@@ -84,9 +90,14 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")

lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
False, False, False, False)
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
lbl_word, lbl_weight = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False)

input_dict = dict(
zip(input_data_names, [
@@ -105,13 +116,11 @@ def main():
exe = fluid.Executor(place)

sum_cost, avg_cost, predict, token_num = transformer(
ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)

lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
TrainTaskConfig.warmup_steps, place,
@@ -145,8 +154,8 @@ def test(exe):
for batch_id, data in enumerate(val_data()):
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
test_sum_cost, test_token_num = exe.run(
test_program,
@@ -171,10 +180,12 @@ def test(exe):
for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()):
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
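For reference, a simplified standalone version of the new `is_label` branch above (assuming plain NumPy and Python 3; the helper name is hypothetical): real label positions get weight 1.0 and padded positions 0.0, so the ids used to pad the labels never influence the gradient.

```python
import numpy as np

def pad_labels(insts, pad_idx):
    """Pad label sequences and return the matching 0/1 loss weights."""
    max_len = max(len(inst) for inst in insts)
    data = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    weight = np.array(
        [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
    return (data.astype("int64").reshape([-1, 1]),
            weight.astype("float32").reshape([-1, 1]))

lbl, w = pad_labels([[5, 7, 9], [3, 4]], pad_idx=1)   # pad with <eos> (id 1)
print(lbl.ravel())   # [5 7 9 3 4 1]
print(w.ravel())     # [1. 1. 1. 1. 1. 0.]
```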