Skip to content

Commit

Permalink
Merge pull request #1784 from luotao1/beam
Browse files Browse the repository at this point in the history
add seqtext_print for seqToseq demo
  • Loading branch information
luotao1 authored Apr 14, 2017
2 parents 92edc2d + 555b2df commit c51ab42
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
39 changes: 35 additions & 4 deletions demo/seqToseq/api_train_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):

def main():
paddle.init(use_gpu=False, trainer_count=1)
is_generating = True
is_generating = False

# source and target dict dim.
dict_size = 30000
Expand Down Expand Up @@ -167,16 +167,47 @@ def event_handler(event):

# generate a english sequence to french
else:
gen_creator = paddle.dataset.wmt14.test(dict_size)
# use the first 3 samples for generation
gen_creator = paddle.dataset.wmt14.gen(dict_size)
gen_data = []
gen_num = 3
for item in gen_creator():
gen_data.append((item[0], ))
if len(gen_data) == 3:
if len(gen_data) == gen_num:
break

beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating)
# get the pretrained model, whose bleu = 26.92
parameters = paddle.dataset.wmt14.model()
trg_dict = paddle.dataset.wmt14.trg_dict(dict_size)
# prob is the prediction probabilities, and id is the prediction word.
beam_result = paddle.infer(
output_layer=beam_gen,
parameters=parameters,
input=gen_data,
field=['prob', 'id'])

# get the dictionary
src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)

# the delimited element of generated sequences is -1,
# the first element of each generated sequence is the sequence length
seq_list = []
seq = []
for w in beam_result[1]:
if w != -1:
seq.append(w)
else:
seq_list.append(' '.join([trg_dict.get(w) for w in seq[1:]]))
seq = []

prob = beam_result[0]
beam_size = 3
for i in xrange(gen_num):
print "\n*******************************************************\n"
print "src:", ' '.join(
[src_dict.get(w) for w in gen_data[i][0]]), "\n"
for j in xrange(beam_size):
print "prob = %f:" % (prob[i][j]), seq_list[i * beam_size + j]


if __name__ == '__main__':
Expand Down
16 changes: 13 additions & 3 deletions python/paddle/v2/dataset/wmt14.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be add later.
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
# this is the pretrained model, whose bleu = 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL = '4ce14a26607fb8a1cc23bcdedb1895e4'
Expand Down Expand Up @@ -108,17 +108,27 @@ def test(dict_size):
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)


def gen(dict_size):
return reader_creator(
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'gen/gen', dict_size)


def model():
tar_file = download(URL_MODEL, 'wmt14', MD5_MODEL)
with gzip.open(tar_file, 'r') as f:
parameters = Parameters.from_tar(f)
return parameters


def trg_dict(dict_size):
def get_dict(dict_size, reverse=True):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
return trg_dict
if reverse:
src_dict = {v: k for k, v in src_dict.items()}
trg_dict = {v: k for k, v in trg_dict.items()}
return src_dict, trg_dict


def fetch():
Expand Down

0 comments on commit c51ab42

Please sign in to comment.