Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor of the attention example #243

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 42 additions & 38 deletions examples/python/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,81 +80,85 @@ def attend(input_vectors, state):
return output_vectors


def make_decoder(input_sequence):
    """Coroutine-style attention decoder over *input_sequence*.

    Yields the softmax probability vector for the next output character,
    then immediately yields again (a bare ``yield``) so the caller can feed
    the chosen character back via ``send(char)``.  Decoding then continues
    conditioned on that character.  The generator never terminates on its
    own -- the caller decides when to stop consuming it.
    """
    # NOTE(review): the original declared ``global decoder_w`` twice -- the
    # second line was presumably meant to be ``decoder_b``.  Since all of
    # these module-level names are only read (never rebound), no ``global``
    # declarations are needed at all, so they are dropped.
    w = dy.parameter(decoder_w)
    b = dy.parameter(decoder_b)

    embedded = embed_sentence(input_sequence)
    encoded = encode_sentence(enc_fwd_lstm, enc_bwd_lstm, embedded)

    # Seed the decoder LSTM with a zero attention context and the EOS embedding.
    last_output_embeddings = output_lookup[char2int[EOS]]
    s = dec_lstm.initial_state().add_input(
        dy.concatenate([dy.vecInput(STATE_SIZE * 2), last_output_embeddings]))
    while True:
        vector = dy.concatenate([attend(encoded, s), last_output_embeddings])
        s = s.add_input(vector)
        out_vector = w * s.output() + b
        probs = dy.softmax(out_vector)
        yield probs  # output decoding probabilities
        prev_char = yield  # expect feedback char through .send(char) to continue decoding
        last_output_embeddings = output_lookup[prev_char]


def train_decode(input_sequence, output_sequence):
    """Return the teacher-forced decoding loss for one training pair.

    The gold *output_sequence* is bracketed with EOS on both sides, mapped
    to character ids, and fed back into the decoder after each step; the
    loss is the summed negative log-probability of every gold character.
    """
    targets = [char2int[c] for c in [EOS] + list(output_sequence) + [EOS]]

    decoder = make_decoder(input_sequence)
    probs = next(decoder)  # first decoding probabilities
    step_losses = []
    for gold in targets:
        step_losses.append(-dy.log(dy.pick(probs, gold)))
        next(decoder)                # move decoder to its feedback yield
        probs = decoder.send(gold)   # teacher-force the gold character
    return dy.esum(step_losses)


def generate(input_sequence):
    """Sample an output string for *input_sequence* from the decoder.

    Characters are drawn from the decoder's probability distributions and
    fed back in, until two EOS characters have been sampled or
    ``2 * len(input_sequence)`` steps have elapsed.
    """
    def sample(probs):
        # Draw an index from a discrete distribution given as a list of
        # probabilities.  Falls back to the last index if rounding leaves
        # a tiny residue of probability mass.
        rnd = random.random()
        for i, p in enumerate(probs):
            rnd -= p
            if rnd <= 0:
                break
        return i

    out = ''
    count_EOS = 0
    decoder = make_decoder(input_sequence)
    probs = next(decoder)  # get first decoding probabilities
    for _ in range(len(input_sequence) * 2):  # loop index unused -> `_`
        if count_EOS == 2:
            break
        char = sample(probs.vec_value())
        next(decoder)  # move decoder to expecting position
        probs = decoder.send(char)  # provide decoder with current sampled char to produce next state
        if int2char[char] == EOS:
            count_EOS += 1
            continue
        out += int2char[char]
    return out


def get_loss(input_sentence, output_sentence, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
    """Compute the decoding loss of *output_sentence* given *input_sentence*.

    Starts a fresh computation graph, encodes the embedded input with the
    forward/backward LSTMs, and delegates loss computation to ``decode``.

    NOTE(review): this is the pre-refactor version -- this diff replaces it
    with ``train_decode`` -- kept here verbatim for documentation only.
    """
    dy.renew_cg()
    embedded = embed_sentence(input_sentence)
    encoded = encode_sentence(enc_fwd_lstm, enc_bwd_lstm, embedded)
    return decode(dec_lstm, encoded, output_sentence)


def train(model, sentence):
    """Overfit the model on a single (sentence -> sentence) pair with SGD.

    Runs 600 updates; every 20 iterations prints the current loss value and
    a freshly sampled decoding of *sentence* as a progress check.
    """
    trainer = dy.SimpleSGDTrainer(model)
    for i in range(600):
        dy.renew_cg()  # fresh computation graph each iteration
        loss = train_decode(sentence, sentence)
        loss_value = loss.value()
        loss.backward()
        trainer.update()
        if i % 20 == 0:
            print(loss_value)
            print(generate(sentence))


# Example entry point: train on a single sentence to sanity-check the model.
train(model, "it is working")