Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
[WIP] Fix a lot of stuff
Browse files Browse the repository at this point in the history
Here is one thing I could not address:
the `samples` are accumulated at each time step,
which cannot be expressed inside `while_loop`.
I have no idea yet how to deal with this.
  • Loading branch information
junrushao committed Jul 31, 2018
1 parent ee4175b commit 7f00f1b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 37 deletions.
65 changes: 32 additions & 33 deletions gluonnlp/model/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,13 @@ def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False, pre
self._single_step = single_step
assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)

def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask, # pylint: disable=arguments-differ
states, vocab_num, batch_shift):
def hybrid_forward(self, F, valid_length, log_probs, scores, step, beam_alive_mask, # pylint: disable=arguments-differ
states, vocab_num, batch_shift, samples):
"""
Parameters
----------
F
samples : NDArray or Symbol or an empty list
The current samples generated by beam search.
An empty list when single_step is True.
(batch_size, beam_size, L) when single_step is False.
valid_length : NDArray or Symbol
The current valid lengths of the samples
log_probs : NDArray or Symbol
Expand All @@ -230,6 +226,10 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam
batch_shift : NDArray or Symbol
Contains [0, beam_size, 2 * beam_size, ..., (batch_size - 1) * beam_size].
Shape (batch_size,)
samples : NDArray or Symbol or an empty list
The current samples generated by beam search.
An empty list when single_step is True.
(batch_size, beam_size, L) when single_step is False.
Returns
-------
Expand Down Expand Up @@ -385,8 +385,8 @@ def __call__(self, inputs, states):
batch_shift_nd = mx.nd.arange(0, batch_size * beam_size, beam_size, ctx=ctx)
step_nd = mx.nd.array([i + 1], ctx=ctx)
samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
self._updater(samples, valid_length, log_probs, scores, step_nd, beam_alive_mask,
new_states, vocab_num_nd, batch_shift_nd)
self._updater(valid_length, log_probs, scores, step_nd, beam_alive_mask,
new_states, vocab_num_nd, batch_shift_nd, samples)
step_input = mx.nd.relu(chosen_word_ids).reshape((-1,))
if mx.nd.sum(beam_alive_mask).asscalar() == 0:
return mx.nd.round(samples).astype(np.int32),\
Expand Down Expand Up @@ -430,7 +430,7 @@ class HybridBeamSearchSampler(HybridBlock):
The score function used in beam search.
max_length : int, default 100
The maximum search length.
vocab_size : int, default None, meaning `decoder._vocab_size`
vocab_num : int, default None, meaning `decoder._vocab_size`
The vocabulary size
"""
def __init__(self, batch_size, beam_size, decoder, eos_id,
Expand Down Expand Up @@ -496,56 +496,55 @@ def hybrid_forward(self, F, inputs, states): # pylint: disable=arguments-diffe
dim=1,
)
beam_alive_mask = F.ones(shape=(batch_size, beam_size))
vocab_num = F.ones([vocab_size])
vocab_num = F.full(shape=(1, ), val=vocab_size)
batch_shift = F.arange(0, batch_size * beam_size, beam_size)

def _loop_cond(_i, _step_input, _states, _valid_length, _scores, beam_alive_mask):
def _loop_cond(_i, _step_input, _valid_length, _scores, beam_alive_mask, *_states):
return F.sum(beam_alive_mask) > 0

def _loop_func(i, step_input, states, valid_length, scores, beam_alive_mask):
def _loop_func(i, step_input, valid_length, scores, beam_alive_mask, *states):
log_probs, new_states = self._decoder(step_input, states)
step = i + 1
_, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
self._updater([], valid_length, log_probs, scores, step, beam_alive_mask,
new_states, vocab_num, batch_shift)
step_input = F.relu(chosen_word_ids).reshape((-1,))
return (chosen_word_ids, ), (step, step_input, states, valid_length, scores, beam_alive_mask)
_, new_valid_length, new_scores, chosen_word_ids, new_beam_alive_mask, new_new_states = \
self._updater(valid_length, log_probs, scores, step, beam_alive_mask,
new_states, vocab_num, batch_shift, [])
new_step_input = F.relu(chosen_word_ids).reshape((-1,))
return (chosen_word_ids, ), (step, new_step_input, new_valid_length, new_scores, new_beam_alive_mask) + tuple(new_new_states)

(samples, ), (_, _, _, valid_length, scores, beam_alive_mask) = \
(samples, ), (_, _, new_valid_length, new_scores, new_beam_alive_mask, _) = \
F.contrib.while_loop(
cond=_loop_cond, func=_loop_func, max_iterations=self._max_length,
loop_vars=(
F.zeros(shape=(1, )), # i
step_input,
states,
valid_length,
scores,
beam_alive_mask
)
) + tuple(states)
)

def _then_func():
new_samples = F.concat(
step_input.reshape((batch_size, beam_size, 1)),
samples,
samples.transpose((1, 2, 0)),
F.full(shape=(batch_size, beam_size, 1), val=-1),
dim=2)
return new_samples, valid_length
new_new_valid_length = new_valid_length
return new_samples, new_new_valid_length

def _else_func():
final_word = F.where(beam_alive_mask,
F.full(shape=(batch_size, beam_size),
val=self._eos_id),
F.full(shape=(batch_size, beam_size),
val=-1))
F.full(shape=(batch_size, beam_size), val=self._eos_id),
F.full(shape=(batch_size, beam_size), val=-1))
new_samples = F.concat(
step_input.reshape((batch_size, beam_size, 1)),
samples,
samples.transpose((1, 2, 0)),
final_word.reshape((0, 0, 1)),
dim=2)
new_valid_length = valid_length + beam_alive_mask
return new_samples, new_valid_length
new_new_valid_length = new_valid_length + new_beam_alive_mask
return new_samples, new_new_valid_length

samples, scores = F.contrib.cond(F.sum(beam_alive_mask) == 0, _then_func, _else_func)
return F.round(samples).astype(np.int32),\
scores,\
F.round(valid_length).astype(np.int32)
new_samples, new_new_valid_length = F.contrib.cond(F.sum(new_beam_alive_mask) == 0, _then_func, _else_func)
return F.round(new_samples).astype(np.int32),\
new_scores,\
F.round(new_new_valid_length).astype(np.int32)
20 changes: 16 additions & 4 deletions tests/unittest/test_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mxnet.gluon.rnn import RNNCell, RNN
from numpy.testing import assert_allclose

from gluonnlp.model import BeamSearchSampler, BeamSearchScorer
from gluonnlp.model import BeamSearchSampler, BeamSearchScorer, HybridBeamSearchSampler


def test_beam_search_score():
Expand All @@ -27,7 +27,9 @@ def test_beam_search_score():


@pytest.mark.seed(1)
def test_beam_search():
@pytest.mark.parametrize('hybridize', [False, True])
@pytest.mark.parametrize('sampler_cls', [HybridBeamSearchSampler, BeamSearchSampler])
def test_beam_search(hybridize, sampler_cls):
def _get_new_states(states, state_info, sel_beam_ids):
assert not state_info or isinstance(state_info, (type(states), dict)), \
'states and state_info don\'t match'
Expand Down Expand Up @@ -229,12 +231,22 @@ def forward(self, inputs, states):
state_info = decoder.state_info()
else:
state_info = None
if sampler_cls is HybridBeamSearchSampler and decoder_fn is RNNLayerDecoder:
# Hybrid beam search does not work on non-hybridizable object
continue
for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
scorer = BeamSearchScorer(alpha=alpha, K=K)
for max_length in [10, 20]:
sampler = BeamSearchSampler(beam_size=beam_size, decoder=decoder, eos_id=eos_id,
scorer=scorer, max_length=max_length)
for batch_size in [1, 2, 5]:
if sampler_cls is HybridBeamSearchSampler:
sampler = sampler_cls(batch_size=batch_size, beam_size=beam_size,
decoder=decoder, eos_id=eos_id, vocab_size=vocab_num,
scorer=scorer, max_length=max_length)
if hybridize:
sampler.hybridize()
else:
sampler = sampler_cls(beam_size=beam_size, decoder=decoder, eos_id=eos_id,
scorer=scorer, max_length=max_length)
print(type(decoder).__name__, beam_size, bos_id, eos_id, alpha, K, batch_size)
states = decoder.begin_state(batch_size)
inputs = mx.nd.full(shape=(batch_size,), val=bos_id)
Expand Down

0 comments on commit 7f00f1b

Please sign in to comment.