diff --git a/gluonnlp/model/beam_search.py b/gluonnlp/model/beam_search.py
index 3c6aa59ad7..4ac2638912 100644
--- a/gluonnlp/model/beam_search.py
+++ b/gluonnlp/model/beam_search.py
@@ -201,17 +201,13 @@ def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False, pre
         self._single_step = single_step
         assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
 
-    def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,  # pylint: disable=arguments-differ
-                       states, vocab_num, batch_shift):
+    def hybrid_forward(self, F, valid_length, log_probs, scores, step, beam_alive_mask,  # pylint: disable=arguments-differ
+                       states, vocab_num, batch_shift, samples):
         """
 
         Parameters
         ----------
         F
-        samples : NDArray or Symbol or an empty list
-            The current samples generated by beam search.
-            An empty list when single_step is True.
-            (batch_size, beam_size, L) when single_step is False.
         valid_length : NDArray or Symbol
             The current valid lengths of the samples
         log_probs : NDArray or Symbol
@@ -230,6 +226,10 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam
         batch_shift : NDArray or Symbol
             Contains [0, beam_size, 2 * beam_size, ..., (batch_size - 1) * beam_size].
             Shape (batch_size,)
+        samples : NDArray or Symbol or an empty list
+            The current samples generated by beam search.
+            An empty list when single_step is True.
+            (batch_size, beam_size, L) when single_step is False.
 
         Returns
         -------
@@ -385,8 +385,8 @@ def __call__(self, inputs, states):
             batch_shift_nd = mx.nd.arange(0, batch_size * beam_size, beam_size, ctx=ctx)
             step_nd = mx.nd.array([i + 1], ctx=ctx)
             samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
-                self._updater(samples, valid_length, log_probs, scores, step_nd, beam_alive_mask,
-                              new_states, vocab_num_nd, batch_shift_nd)
+                self._updater(valid_length, log_probs, scores, step_nd, beam_alive_mask,
+                              new_states, vocab_num_nd, batch_shift_nd, samples)
             step_input = mx.nd.relu(chosen_word_ids).reshape((-1,))
             if mx.nd.sum(beam_alive_mask).asscalar() == 0:
                 return mx.nd.round(samples).astype(np.int32),\
@@ -430,7 +430,7 @@ class HybridBeamSearchSampler(HybridBlock):
         The score function used in beam search.
     max_length : int, default 100
         The maximum search length.
-    vocab_num : int, default None, meaning `decoder._vocab_size`
+    vocab_size : int, default None, meaning `decoder._vocab_size`
         The vocabulary size
     """
     def __init__(self, batch_size, beam_size, decoder, eos_id,
@@ -496,56 +496,55 @@ def hybrid_forward(self, F, inputs, states):  # pylint: disable=arguments-diffe
             dim=1,
         )
         beam_alive_mask = F.ones(shape=(batch_size, beam_size))
-        vocab_num = F.ones([vocab_size])
+        vocab_num = F.full(shape=(1, ), val=vocab_size)
         batch_shift = F.arange(0, batch_size * beam_size, beam_size)
 
-        def _loop_cond(_i, _step_input, _states, _valid_length, _scores, beam_alive_mask):
+        def _loop_cond(_i, _step_input, _valid_length, _scores, beam_alive_mask, *_states):
             return F.sum(beam_alive_mask) > 0
 
-        def _loop_func(i, step_input, states, valid_length, scores, beam_alive_mask):
+        def _loop_func(i, step_input, valid_length, scores, beam_alive_mask, *states):
             log_probs, new_states = self._decoder(step_input, states)
             step = i + 1
-            _, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
-                self._updater([], valid_length, log_probs, scores, step, beam_alive_mask,
-                              new_states, vocab_num, batch_shift)
-            step_input = F.relu(chosen_word_ids).reshape((-1,))
-            return (chosen_word_ids, ), (step, step_input, states, valid_length, scores, beam_alive_mask)
+            _, new_valid_length, new_scores, chosen_word_ids, new_beam_alive_mask, new_new_states = \
+                self._updater(valid_length, log_probs, scores, step, beam_alive_mask,
+                              new_states, vocab_num, batch_shift, [])
+            new_step_input = F.relu(chosen_word_ids).reshape((-1,))
+            return (chosen_word_ids, ), (step, new_step_input, new_valid_length, new_scores, new_beam_alive_mask) + tuple(new_new_states)
 
-        (samples, ), (_, _, _, valid_length, scores, beam_alive_mask) = \
+        (samples, ), (_, _, new_valid_length, new_scores, new_beam_alive_mask, _) = \
             F.contrib.while_loop(
                 cond=_loop_cond, func=_loop_func, max_iterations=self._max_length,
                 loop_vars=(
                     F.zeros(shape=(1, )),  # i
                     step_input,
-                    states,
                     valid_length,
                     scores,
                     beam_alive_mask
-                )
+                ) + tuple(states)
             )
 
         def _then_func():
             new_samples = F.concat(
                 step_input.reshape((batch_size, beam_size, 1)),
-                samples,
+                samples.transpose((1, 2, 0)),
+                F.full(shape=(batch_size, beam_size, 1), val=-1),
                 dim=2)
-            return new_samples, valid_length
+            new_new_valid_length = new_valid_length
+            return new_samples, new_new_valid_length
 
         def _else_func():
-            final_word = F.where(beam_alive_mask,
-                                 F.full(shape=(batch_size, beam_size),
-                                        val=self._eos_id),
-                                 F.full(shape=(batch_size, beam_size),
-                                        val=-1))
+            final_word = F.where(new_beam_alive_mask,
+                                 F.full(shape=(batch_size, beam_size), val=self._eos_id),
+                                 F.full(shape=(batch_size, beam_size), val=-1))
             new_samples = F.concat(
                 step_input.reshape((batch_size, beam_size, 1)),
-                samples,
+                samples.transpose((1, 2, 0)),
                 final_word.reshape((0, 0, 1)),
                 dim=2)
-            new_valid_length = valid_length + beam_alive_mask
-            return new_samples, new_valid_length
+            new_new_valid_length = new_valid_length + new_beam_alive_mask
+            return new_samples, new_new_valid_length
 
-        samples, scores = F.contrib.cond(F.sum(beam_alive_mask) == 0, _then_func, _else_func)
-        return F.round(samples).astype(np.int32),\
-               scores,\
-               F.round(valid_length).astype(np.int32)
+        new_samples, new_new_valid_length = F.contrib.cond(F.sum(new_beam_alive_mask) == 0, _then_func, _else_func)
+        return F.round(new_samples).astype(np.int32),\
+               new_scores,\
+               F.round(new_new_valid_length).astype(np.int32)
diff --git a/tests/unittest/test_beam_search.py b/tests/unittest/test_beam_search.py
index c988ae0a9e..1f3c0b38d4 100644
--- a/tests/unittest/test_beam_search.py
+++ b/tests/unittest/test_beam_search.py
@@ -7,7 +7,7 @@
 from mxnet.gluon.rnn import RNNCell, RNN
 from numpy.testing import assert_allclose
 
-from gluonnlp.model import BeamSearchSampler, BeamSearchScorer
+from gluonnlp.model import BeamSearchSampler, BeamSearchScorer, HybridBeamSearchSampler
 
 
 def test_beam_search_score():
@@ -27,7 +27,9 @@ def test_beam_search_score():
 
 
 @pytest.mark.seed(1)
-def test_beam_search():
+@pytest.mark.parametrize('hybridize', [False, True])
+@pytest.mark.parametrize('sampler_cls', [HybridBeamSearchSampler, BeamSearchSampler])
+def test_beam_search(hybridize, sampler_cls):
     def _get_new_states(states, state_info, sel_beam_ids):
         assert not state_info or isinstance(state_info, (type(states), dict)), \
             'states and state_info don\'t match'
@@ -229,12 +231,22 @@ def forward(self, inputs, states):
             state_info = decoder.state_info()
         else:
             state_info = None
+        if sampler_cls is HybridBeamSearchSampler and decoder_fn is RNNLayerDecoder:
+            # Hybrid beam search does not work on a non-hybridizable object
+            continue
         for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
             scorer = BeamSearchScorer(alpha=alpha, K=K)
             for max_length in [10, 20]:
-                sampler = BeamSearchSampler(beam_size=beam_size, decoder=decoder, eos_id=eos_id,
-                                            scorer=scorer, max_length=max_length)
                 for batch_size in [1, 2, 5]:
+                    if sampler_cls is HybridBeamSearchSampler:
+                        sampler = sampler_cls(batch_size=batch_size, beam_size=beam_size,
+                                              decoder=decoder, eos_id=eos_id, vocab_size=vocab_num,
+                                              scorer=scorer, max_length=max_length)
+                        if hybridize:
+                            sampler.hybridize()
+                    else:
+                        sampler = sampler_cls(beam_size=beam_size, decoder=decoder, eos_id=eos_id,
+                                              scorer=scorer, max_length=max_length)
                     print(type(decoder).__name__, beam_size, bos_id, eos_id, alpha, K, batch_size)
                     states = decoder.begin_state(batch_size)
                     inputs = mx.nd.full(shape=(batch_size,), val=bos_id)
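
For context, a minimal usage sketch of the hybridizable sampler exercised by the updated test above. The ToyDecoder class, its layer sizes, and the token ids below are hypothetical stand-ins and not part of the patch; any single-state HybridBlock that maps (step_input, states) to (log_probs, new_states) should slot in the same way, mirroring how the unit test constructs the sampler.

import mxnet as mx
from mxnet.gluon import HybridBlock, nn, rnn
from gluonnlp.model import BeamSearchScorer, HybridBeamSearchSampler


class ToyDecoder(HybridBlock):
    """Hypothetical single-state decoder: (step_input, states) -> (log_probs, new_states)."""
    def __init__(self, vocab_size=10, hidden_size=8, **kwargs):
        super(ToyDecoder, self).__init__(**kwargs)
        self._vocab_size = vocab_size  # read by the sampler when vocab_size is not given
        with self.name_scope():
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.cell = rnn.GRUCell(hidden_size)
            self.proj = nn.Dense(vocab_size, flatten=False)

    def begin_state(self, batch_size):
        return self.cell.begin_state(batch_size=batch_size)

    def hybrid_forward(self, F, step_input, states):  # pylint: disable=arguments-differ
        out, new_states = self.cell(self.embed(step_input), states)
        return F.log_softmax(self.proj(out)), new_states


batch_size, beam_size, bos_id, eos_id, vocab_size = 2, 4, 1, 3, 10
decoder = ToyDecoder(vocab_size=vocab_size)
decoder.initialize()
sampler = HybridBeamSearchSampler(batch_size=batch_size, beam_size=beam_size,
                                  decoder=decoder, eos_id=eos_id, vocab_size=vocab_size,
                                  scorer=BeamSearchScorer(alpha=1.0, K=5), max_length=10)
sampler.hybridize()  # the search loop is built from F.contrib.while_loop/cond, so it can be hybridized
inputs = mx.nd.full(shape=(batch_size,), val=bos_id)  # BOS token id per sequence
states = decoder.begin_state(batch_size)
samples, scores, valid_lengths = sampler(inputs, states)
# samples: (batch_size, beam_size, L); scores and valid_lengths: (batch_size, beam_size)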