diff --git a/gluonnlp/model/beam_search.py b/gluonnlp/model/beam_search.py
index 7feab989c5..14d1c5b0a8 100644
--- a/gluonnlp/model/beam_search.py
+++ b/gluonnlp/model/beam_search.py
@@ -20,7 +20,7 @@
 from __future__ import absolute_import
 from __future__ import print_function
 
-__all__ = ['BeamSearchScorer', 'BeamSearchSampler']
+__all__ = ['BeamSearchScorer', 'BeamSearchSampler', 'HybridBeamSearchSampler']
 
 import numpy as np
 import mxnet as mx
@@ -75,12 +75,72 @@ def hybrid_forward(self, F, log_probs, scores, step):   # pylint: disable=arguments-differ
         return candidate_scores
 
 
+def _extract_and_flatten_nested_structure(data, flattened=None):
+    """Flatten the structure of a nested container to a list.
+
+    Parameters
+    ----------
+    data : A single NDArray/Symbol or nested container with NDArrays/Symbols.
+        The nested container to be flattened.
+    flattened : list or None
+        The container that holds the flattened result.
+
+    Returns
+    -------
+    structure : An integer or a nested container with integers.
+        The extracted structure of the container of `data`.
+    flattened : (optional) list
+        The container that holds the flattened result.
+        It is returned only when the input argument `flattened` is not given.
+    """
+    if flattened is None:
+        flattened = []
+        structure = _extract_and_flatten_nested_structure(data, flattened)
+        return structure, flattened
+    if isinstance(data, list):
+        return list(_extract_and_flatten_nested_structure(x, flattened) for x in data)
+    elif isinstance(data, tuple):
+        return tuple(_extract_and_flatten_nested_structure(x, flattened) for x in data)
+    elif isinstance(data, dict):
+        return {k: _extract_and_flatten_nested_structure(v, flattened) for k, v in data.items()}
+    elif isinstance(data, (mx.sym.Symbol, mx.nd.NDArray)):
+        flattened.append(data)
+        return len(flattened) - 1
+    else:
+        raise NotImplementedError
+
+
+def _reconstruct_flattened_structure(structure, flattened):
+    """Reconstruct the flattened list back to a (possibly) nested structure.
+
+    Parameters
+    ----------
+    structure : An integer or a nested container with integers.
+        The extracted structure of the container of `data`.
+    flattened : list
+        The container that holds the flattened result.
+
+    Returns
+    -------
+    data : A single NDArray/Symbol or nested container with NDArrays/Symbols.
+        The nested container that was flattened.
+    """
+    if isinstance(structure, list):
+        return list(_reconstruct_flattened_structure(x, flattened) for x in structure)
+    elif isinstance(structure, tuple):
+        return tuple(_reconstruct_flattened_structure(x, flattened) for x in structure)
+    elif isinstance(structure, dict):
+        return {k: _reconstruct_flattened_structure(v, flattened) for k, v in structure.items()}
+    elif isinstance(structure, int):
+        return flattened[structure]
+    else:
+        raise NotImplementedError
+
+
 def _expand_to_beam_size(data, beam_size, batch_size, state_info=None):
     """Tile all the states to have batch_size * beam_size on the batch axis.
 
     Parameters
     ----------
-    data : A single NDArray or nested container with NDArrays
+    data : A single NDArray/Symbol or nested container with NDArrays/Symbols
         Each NDArray/Symbol should have shape (N, ...) when state_info is None,
         or same as the layout in state_info when it's not None.
     beam_size : int
@@ -92,8 +152,8 @@ def _expand_to_beam_size(data, beam_size, batch_size, state_info=None):
         When None, this method assumes that the batch axis is the first dimension.
     Returns
     -------
-    new_states : Object that contains NDArrays
-        Each NDArray should have shape batch_size * beam_size on the batch axis.
+    new_states : Object that contains NDArrays/Symbols
+        Each NDArray/Symbol should have shape batch_size * beam_size on the batch axis.
     """
     assert not state_info or isinstance(state_info, (type(data), dict)), \
         'data and state_info doesn\'t match, ' \
@@ -128,6 +188,15 @@ def _expand_to_beam_size(data, beam_size, batch_size, state_info=None):
         return data.expand_dims(batch_axis+1)\
                    .broadcast_axes(axis=batch_axis+1, size=beam_size)\
                    .reshape(new_shape)
+    elif isinstance(data, mx.sym.Symbol):
+        if not state_info:
+            batch_axis = 0
+        else:
+            batch_axis = state_info['__layout__'].find('N')
+        new_shape = (0, ) * batch_axis + (-3, -2)
+        return data.expand_dims(batch_axis+1)\
+                   .broadcast_axes(axis=batch_axis+1, size=beam_size)\
+                   .reshape(new_shape)
     else:
         raise NotImplementedError
 
@@ -183,23 +252,27 @@ def _choose_states(F, states, state_info, indices):
 
 
 class _BeamSearchStepUpdate(HybridBlock):
-    def __init__(self, beam_size, eos_id, scorer, state_info, prefix=None, params=None):
+    def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False,
+                 prefix=None, params=None):
         super(_BeamSearchStepUpdate, self).__init__(prefix, params)
         self._beam_size = beam_size
         self._eos_id = eos_id
         self._scorer = scorer
         self._state_info = state_info
+        self._single_step = single_step
         assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
 
-    def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,   # pylint: disable=arguments-differ
-                       states, vocab_num, batch_shift):
+    def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,   # pylint: disable=arguments-differ
+                       states, vocab_size, batch_shift):
         """
 
         Parameters
         ----------
        F
        samples : NDArray or Symbol
-            The current samples generated by beam search. Shape (batch_size, beam_size, L)
+            The current samples generated by beam search.
+            When single_step is True, (batch_size, beam_size, max_length).
+            When single_step is False, (batch_size, beam_size, L).
        valid_length : NDArray or Symbol
            The current valid lengths of the samples
        log_probs : NDArray or Symbol
@@ -213,7 +286,7 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,
        states : nested structure of NDArrays/Symbols
            Each NDArray/Symbol should have shape (N, ...) when state_info is None,
            or same as the layout in state_info when it's not None.
-        vocab_num : NDArray or Symbol
+        vocab_size : NDArray or Symbol
            Shape (1,)
        batch_shift : NDArray or Symbol
            Contains [0, beam_size, 2 * beam_size, ..., (batch_size - 1) * beam_size].
@@ -221,8 +294,10 @@
        Returns
        -------
-        new_samples : NDArray or Symbol
-            The updated samples. Shape (batch_size, beam_size, L + 1)
+        new_samples : NDArray or Symbol or an empty list
+            The updated samples.
+            When single_step is True, it is an empty list.
+            When single_step is False, shape (batch_size, beam_size, L + 1)
        new_valid_length : NDArray or Symbol
            Valid lengths of the samples.
            Shape (batch_size, beam_size)
        new_scores : NDArray or Symbol
@@ -250,11 +325,11 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,
                                      finished_scores, dim=1)
         # Get the top K scores
         new_scores, indices = F.topk(candidate_scores, axis=1, k=beam_size, ret_typ='both')
-        use_prev = F.broadcast_greater_equal(indices, beam_size * vocab_num)
-        chosen_word_ids = F.broadcast_mod(indices, vocab_num)
+        use_prev = F.broadcast_greater_equal(indices, beam_size * vocab_size)
+        chosen_word_ids = F.broadcast_mod(indices, vocab_size)
         beam_ids = F.where(use_prev,
-                           F.broadcast_minus(indices, beam_size * vocab_num),
-                           F.floor(F.broadcast_div(indices, vocab_num)))
+                           F.broadcast_minus(indices, beam_size * vocab_size),
+                           F.floor(F.broadcast_div(indices, vocab_size)))
         batch_beam_indices = F.broadcast_add(beam_ids, F.expand_dims(batch_shift, axis=1))
         chosen_word_ids = F.where(use_prev,
                                   -F.ones_like(indices),
@@ -264,6 +339,8 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,
                                    batch_beam_indices.reshape(shape=(-1,))),
                             chosen_word_ids.reshape(shape=(-1, 1)), dim=1)\
                      .reshape(shape=(-4, -1, beam_size, 0))
+        if self._single_step:
+            new_samples = new_samples.slice_axis(axis=2, begin=1, end=None)
         new_valid_length = F.take(valid_length.reshape(shape=(-1,)),
                                   batch_beam_indices.reshape(shape=(-1,))).reshape((-1, beam_size))\
                            + 1 - use_prev
@@ -363,12 +440,12 @@ def __call__(self, inputs, states):
         samples = step_input.reshape((batch_size, beam_size, 1))
         for i in range(self._max_length):
             log_probs, new_states = self._decoder(step_input, states)
-            vocab_num_nd = mx.nd.array([log_probs.shape[1]], ctx=ctx)
+            vocab_size_nd = mx.nd.array([log_probs.shape[1]], ctx=ctx)
             batch_shift_nd = mx.nd.arange(0, batch_size * beam_size, beam_size, ctx=ctx)
             step_nd = mx.nd.array([i + 1], ctx=ctx)
             samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
                 self._updater(samples, valid_length, log_probs, scores, step_nd, beam_alive_mask,
-                              new_states, vocab_num_nd, batch_shift_nd)
+                              new_states, vocab_size_nd, batch_shift_nd)
             step_input = mx.nd.relu(chosen_word_ids).reshape((-1,))
             if mx.nd.sum(beam_alive_mask).asscalar() == 0:
                 return mx.nd.round(samples).astype(np.int32),\
@@ -384,3 +461,164 @@ def __call__(self, inputs, states):
         return mx.nd.round(samples).astype(np.int32),\
                scores,\
                mx.nd.round(valid_length).astype(np.int32)
+
+
+class HybridBeamSearchSampler(HybridBlock):
+    r"""Draw samples from the decoder by beam search.
+
+    Parameters
+    ----------
+    batch_size : int
+        The batch size.
+    beam_size : int
+        The beam size.
+    decoder : callable, must be hybridizable
+        Function of the one-step-ahead decoder, should have the form::
+
+            log_probs, new_states = decoder(step_input, states)
+
+        The log_probs and step_input should follow these rules:
+
+        - step_input has shape (batch_size,),
+        - log_probs has shape (batch_size, V),
+        - states and new_states have the same structure and the leading
+          dimension of the inner NDArrays is the batch dimension.
+    eos_id : int
+        Id of the EOS token. No other elements will be appended to the sample if it reaches eos_id.
+    scorer : BeamSearchScorer, default BeamSearchScorer(alpha=1.0, K=5), must be hybridizable
+        The score function used in beam search.
+    max_length : int, default 100
+        The maximum search length.
+    vocab_size : int, default None
+        The vocabulary size. When None, `decoder._vocab_size` is used.
+    """
+    def __init__(self, batch_size, beam_size, decoder, eos_id,
+                 scorer=BeamSearchScorer(alpha=1.0, K=5),
+                 max_length=100, vocab_size=None,
+                 prefix=None, params=None):
+        super(HybridBeamSearchSampler, self).__init__(prefix, params)
+        self._batch_size = batch_size
+        self._beam_size = beam_size
+        assert beam_size > 0,\
+            'beam_size must be larger than 0. Received beam_size={}'.format(beam_size)
+        self._decoder = decoder
+        self._eos_id = eos_id
+        assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
+        self._max_length = max_length
+        self._scorer = scorer
+        self._state_info_func = getattr(decoder, 'state_info', lambda _=None: None)
+        self._updater = _BeamSearchStepUpdate(beam_size=beam_size, eos_id=eos_id, scorer=scorer,
+                                              single_step=True, state_info=self._state_info_func())
+        self._updater.hybridize()
+        self._vocab_size = vocab_size or getattr(decoder, '_vocab_size', None)
+        assert self._vocab_size is not None,\
+            'Please provide vocab_size or define decoder._vocab_size'
+        assert not hasattr(decoder, '_vocab_size') or decoder._vocab_size == self._vocab_size, \
+            'Provided vocab_size={} is not equal to decoder._vocab_size={}'\
+            .format(self._vocab_size, decoder._vocab_size)
+
+    def hybrid_forward(self, F, inputs, states):   # pylint: disable=arguments-differ
+        """Sample by beam search.
+
+        Parameters
+        ----------
+        F
+        inputs : NDArray or Symbol
+            The initial input of the decoder. Shape is (batch_size,).
+        states : Object that contains NDArrays or Symbols
+            The initial states of the decoder.
+
+        Returns
+        -------
+        samples : NDArray or Symbol
+            Samples drawn by beam search. Shape (batch_size, beam_size, length). dtype is int32.
+        scores : NDArray or Symbol
+            Scores of the samples. Shape (batch_size, beam_size). We make sure that scores[i, :]
+            are in descending order.
+        valid_length : NDArray or Symbol
+            The valid length of the samples. Shape (batch_size, beam_size). dtype will be int32.
+        """
+        batch_size = self._batch_size
+        beam_size = self._beam_size
+        vocab_size = self._vocab_size
+        # Tile the states and inputs to have shape (batch_size * beam_size, ...)
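+        # E.g., an input of shape (batch_size,) becomes (batch_size * beam_size,) and a
+        # state of shape (batch_size, C) becomes (batch_size * beam_size, C), so the
+        # decoder can treat each beam candidate as an independent batch entry.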
+        state_info = self._state_info_func(batch_size)
+        step_input = _expand_to_beam_size(inputs, beam_size=beam_size, batch_size=batch_size)
+        states = _expand_to_beam_size(states, beam_size=beam_size, batch_size=batch_size,
+                                      state_info=state_info)
+        state_structure, states = _extract_and_flatten_nested_structure(states)
+        if beam_size == 1:
+            init_scores = F.zeros(shape=(batch_size, 1))
+        else:
+            init_scores = F.concat(
+                F.zeros(shape=(batch_size, 1)),
+                F.full(shape=(batch_size, beam_size - 1), val=LARGE_NEGATIVE_FLOAT),
+                dim=1,
+            )
+        vocab_size = F.full(shape=(1, ), val=vocab_size)
+        batch_shift = F.arange(0, batch_size * beam_size, beam_size)
+
+        def _loop_cond(_i, _samples, _indices, _step_input, _valid_length, _scores,
+                       beam_alive_mask, *_states):
+            return F.sum(beam_alive_mask) > 0
+
+        def _loop_func(i, samples, indices, step_input, valid_length, scores,
+                       beam_alive_mask, *states):
+            log_probs, new_states = self._decoder(
+                step_input, _reconstruct_flattened_structure(state_structure, states))
+            step = i + 1
+            new_samples, new_valid_length, new_scores, \
+                chosen_word_ids, new_beam_alive_mask, new_new_states = \
+                self._updater(samples, valid_length, log_probs, scores, step, beam_alive_mask,
+                              _extract_and_flatten_nested_structure(new_states)[-1],
+                              vocab_size, batch_shift)
+            new_step_input = F.relu(chosen_word_ids).reshape((-1,))
+            # Rotate the write indices one step to the left, i.e.
+            # `new_indices = indices[1 : ] + indices[ : 1]`
+            new_indices = F.concat(
+                indices.slice_axis(axis=0, begin=1, end=None),
+                indices.slice_axis(axis=0, begin=0, end=1),
+                dim=0,
+            )
+            return [], (step, new_samples, new_indices, new_step_input, new_valid_length,
+                        new_scores, new_beam_alive_mask) + tuple(new_new_states)
+
+        _, pad_samples, indices, _, new_valid_length, new_scores, new_beam_alive_mask = \
+            F.contrib.while_loop(
+                cond=_loop_cond, func=_loop_func, max_iterations=self._max_length,
+                loop_vars=(
+                    F.zeros(shape=(1, )),                                       # i
+                    F.zeros(shape=(batch_size, beam_size, self._max_length)),  # samples
+                    F.arange(start=0, stop=self._max_length),                  # indices
+                    step_input,                                                 # step_input
+                    F.ones(shape=(batch_size, beam_size)),                     # valid_length
+                    init_scores,                                                # scores
+                    F.ones(shape=(batch_size, beam_size)),                     # beam_alive_mask
+                ) + tuple(states)
+            )[1][:7]   # Python 2 has no extended unpacking, so slice out the non-state outputs
+        samples = pad_samples.take(indices, axis=2)
+
+        def _then_func():
+            new_samples = F.concat(
+                step_input.reshape((batch_size, beam_size, 1)),
+                samples,
+                F.full(shape=(batch_size, beam_size, 1), val=-1),
+                dim=2)
+            new_new_valid_length = new_valid_length
+            return new_samples, new_new_valid_length
+
+        def _else_func():
+            final_word = F.where(new_beam_alive_mask,
+                                 F.full(shape=(batch_size, beam_size), val=self._eos_id),
+                                 F.full(shape=(batch_size, beam_size), val=-1))
+            new_samples = F.concat(
+                step_input.reshape((batch_size, beam_size, 1)),
+                samples,
+                final_word.reshape((0, 0, 1)),
+                dim=2)
+            new_new_valid_length = new_valid_length + new_beam_alive_mask
+            return new_samples, new_new_valid_length
+
+        new_samples, new_new_valid_length = \
+            F.contrib.cond(F.sum(new_beam_alive_mask) == 0, _then_func, _else_func)
+        return F.round(new_samples).astype(np.int32),\
+               new_scores,\
+               F.round(new_new_valid_length).astype(np.int32)
diff --git a/tests/unittest/test_beam_search.py b/tests/unittest/test_beam_search.py
index c988ae0a9e..511ab937d0 100644
--- a/tests/unittest/test_beam_search.py
+++ b/tests/unittest/test_beam_search.py
@@ -7,7 +7,7 @@
 from mxnet.gluon.rnn import RNNCell, RNN
 from numpy.testing import assert_allclose
 
-from gluonnlp.model import BeamSearchSampler, BeamSearchScorer
+from gluonnlp.model import BeamSearchSampler, BeamSearchScorer, HybridBeamSearchSampler
 
 
 def test_beam_search_score():
@@ -27,7 +27,9 @@ def test_beam_search_score():
 
 
 @pytest.mark.seed(1)
-def test_beam_search():
+@pytest.mark.parametrize('hybridize', [False, True])
+@pytest.mark.parametrize('sampler_cls', [HybridBeamSearchSampler, BeamSearchSampler])
+def test_beam_search(hybridize, sampler_cls):
     def _get_new_states(states, state_info, sel_beam_ids):
         assert not state_info or isinstance(state_info, (type(states), dict)), \
             'states and state_info don\'t match'
@@ -106,7 +108,7 @@ def _npy_beam_search(decoder, scorer, inputs, states, eos_id, beam_size, max_length,
         state_info = None
         for step in range(max_length):
             log_probs, states = decoder(mx.nd.array(inputs), states)
-            vocab_num = log_probs.shape[1]
+            vocab_size = log_probs.shape[1]
             candidate_scores = scorer(log_probs, mx.nd.array(scores),
                                       mx.nd.array([step + 1])).asnumpy()
             beam_done_inds = np.where(beam_done)[0]
@@ -122,12 +124,12 @@ def _npy_beam_search(decoder, scorer, inputs, states, eos_id, beam_size, max_length,
             sel_beam_ids = []
             new_scores = candidate_scores[indices]
             for ind in indices:
-                if ind < beam_size * vocab_num:
-                    sel_words.append(ind % vocab_num)
-                    sel_beam_ids.append(ind // vocab_num)
+                if ind < beam_size * vocab_size:
+                    sel_words.append(ind % vocab_size)
+                    sel_beam_ids.append(ind // vocab_size)
                 else:
                     sel_words.append(-1)
-                    sel_beam_ids.append(beam_done_inds[ind - beam_size * vocab_num])
+                    sel_beam_ids.append(beam_done_inds[ind - beam_size * vocab_size])
             states = _get_new_states(states, state_info, sel_beam_ids)
             samples = np.concatenate((samples[sel_beam_ids, :],
                                       np.expand_dims(np.array(sel_words), axis=1)), axis=1)
@@ -143,13 +145,13 @@ def _npy_beam_search(decoder, scorer, inputs, states, eos_id, beam_size, max_length,
     HIDDEN_SIZE = 2
 
     class RNNDecoder(HybridBlock):
-        def __init__(self, vocab_num, hidden_size, prefix=None, params=None):
+        def __init__(self, vocab_size, hidden_size, prefix=None, params=None):
             super(RNNDecoder, self).__init__(prefix=prefix, params=params)
-            self._vocab_num = vocab_num
+            self._vocab_size = vocab_size
             with self.name_scope():
-                self._embed = nn.Embedding(input_dim=vocab_num, output_dim=hidden_size)
+                self._embed = nn.Embedding(input_dim=vocab_size, output_dim=hidden_size)
                 self._rnn = RNNCell(hidden_size=hidden_size)
-                self._map_to_vocab = nn.Dense(vocab_num)
+                self._map_to_vocab = nn.Dense(vocab_size)
 
         def begin_state(self, batch_size):
             return self._rnn.begin_state(batch_size=batch_size,
@@ -161,15 +163,15 @@ def hybrid_forward(self, F, inputs, states):
             return log_probs, states
 
     class RNNDecoder2(HybridBlock):
-        def __init__(self, vocab_num, hidden_size, prefix=None, params=None, use_tuple=False):
+        def __init__(self, vocab_size, hidden_size, prefix=None, params=None, use_tuple=False):
             super(RNNDecoder2, self).__init__(prefix=prefix, params=params)
-            self._vocab_num = vocab_num
+            self._vocab_size = vocab_size
             self._use_tuple = use_tuple
             with self.name_scope():
-                self._embed = nn.Embedding(input_dim=vocab_num, output_dim=hidden_size)
+                self._embed = nn.Embedding(input_dim=vocab_size, output_dim=hidden_size)
                 self._rnn1 = RNNCell(hidden_size=hidden_size)
                 self._rnn2 = RNNCell(hidden_size=hidden_size)
-                self._map_to_vocab = nn.Dense(vocab_num)
+                self._map_to_vocab = nn.Dense(vocab_size)
 
         def begin_state(self, batch_size):
             ret = [self._rnn1.begin_state(batch_size=batch_size,
@@ -195,14 +197,14 @@ def hybrid_forward(self, F, inputs, states):
             states = [states1, states2]
             return log_probs, states
 
-    class RNNLayerDecoder(Block):
-        def __init__(self, vocab_num, hidden_size, prefix=None, params=None):
+    class RNNLayerDecoder(HybridBlock):
+        def __init__(self, vocab_size, hidden_size, prefix=None, params=None):
             super(RNNLayerDecoder, self).__init__(prefix=prefix, params=params)
-            self._vocab_num = vocab_num
+            self._vocab_size = vocab_size
             with self.name_scope():
-                self._embed = nn.Embedding(input_dim=vocab_num, output_dim=hidden_size)
+                self._embed = nn.Embedding(input_dim=vocab_size, output_dim=hidden_size)
                 self._rnn = RNN(hidden_size=hidden_size, num_layers=1, activation='tanh')
-                self._map_to_vocab = nn.Dense(vocab_num, flatten=False)
+                self._map_to_vocab = nn.Dense(vocab_size, flatten=False)
 
         def begin_state(self, batch_size):
             return self._rnn.begin_state(batch_size=batch_size,
@@ -211,31 +213,41 @@ def begin_state(self, batch_size):
 
         def state_info(self, *args, **kwargs):
             return self._rnn.state_info(*args, **kwargs)
 
-        def forward(self, inputs, states):
+        def hybrid_forward(self, F, inputs, states):
             out, states = self._rnn(self._embed(inputs.expand_dims(0)), states)
-            log_probs = self._map_to_vocab(out)[0].log_softmax()
+            log_probs = self._map_to_vocab(out).squeeze(axis=0).log_softmax()
             return log_probs, states
 
     # Begin Testing
-    for vocab_num in [4, 8]:
+    for vocab_size in [2, 3]:
         for decoder_fn in [RNNDecoder,
                            functools.partial(RNNDecoder2, use_tuple=False),
                            functools.partial(RNNDecoder2, use_tuple=True),
                            RNNLayerDecoder]:
-            decoder = decoder_fn(vocab_num=vocab_num, hidden_size=HIDDEN_SIZE)
+            decoder = decoder_fn(vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)
             decoder.hybridize()
             decoder.initialize()
             if hasattr(decoder, 'state_info'):
                 state_info = decoder.state_info()
             else:
                 state_info = None
-            for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
+            for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
                 scorer = BeamSearchScorer(alpha=alpha, K=K)
-                for max_length in [10, 20]:
-                    sampler = BeamSearchSampler(beam_size=beam_size, decoder=decoder, eos_id=eos_id,
-                                                scorer=scorer, max_length=max_length)
-                    for batch_size in [1, 2, 5]:
-                        print(type(decoder).__name__, beam_size, bos_id, eos_id, alpha, K, batch_size)
+                for max_length in [2, 3]:
+                    for batch_size in [1, 5]:
+                        if sampler_cls is HybridBeamSearchSampler:
+                            sampler = sampler_cls(beam_size=beam_size, decoder=decoder,
+                                                  eos_id=eos_id,
+                                                  scorer=scorer, max_length=max_length,
+                                                  vocab_size=vocab_size, batch_size=batch_size)
+                            if hybridize:
+                                sampler.hybridize()
+                        else:
+                            sampler = sampler_cls(beam_size=beam_size, decoder=decoder,
+                                                  eos_id=eos_id,
+                                                  scorer=scorer, max_length=max_length)
+                        print(type(decoder).__name__, beam_size, bos_id, eos_id,
+                              alpha, K, batch_size)
                         states = decoder.begin_state(batch_size)
                         inputs = mx.nd.full(shape=(batch_size,), val=bos_id)
                         samples, scores, valid_length = sampler(inputs, states)
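
Note on the new helpers: `_extract_and_flatten_nested_structure` walks a nested
list/tuple/dict of NDArrays or Symbols, records each leaf's position as an integer, and
returns the skeleton plus a flat list of leaves; `_reconstruct_flattened_structure`
inverts the mapping, which lets the nested decoder states pass through
`F.contrib.while_loop` as plain positional loop variables. A minimal round-trip sketch
(the helpers are module-private, so the import below is for illustration only):

    import mxnet as mx
    from gluonnlp.model.beam_search import (_extract_and_flatten_nested_structure,
                                            _reconstruct_flattened_structure)

    states = [mx.nd.zeros((2, 4)), {'h': mx.nd.ones((2, 4))}]
    structure, flat = _extract_and_flatten_nested_structure(states)
    # structure == [0, {'h': 1}]; flat holds the two NDArrays in visit order
    rebuilt = _reconstruct_flattened_structure(structure, flat)
    # rebuilt has the same nesting as `states`, with the same arrays re-inserted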
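
Example of driving the new sampler, mirroring the parametrized test above (the decoder
class and all values are taken from tests/unittest/test_beam_search.py; this is a sketch,
not part of the patch):

    import mxnet as mx
    from gluonnlp.model import BeamSearchScorer, HybridBeamSearchSampler

    # RNNDecoder is the toy decoder defined in the test:
    # decoder(step_input, states) -> (log_probs, new_states), with _vocab_size set.
    decoder = RNNDecoder(vocab_size=3, hidden_size=2)
    decoder.hybridize()
    decoder.initialize()

    batch_size, beam_size, bos_id, eos_id = 1, 2, 1, 3
    sampler = HybridBeamSearchSampler(batch_size=batch_size, beam_size=beam_size,
                                      decoder=decoder, eos_id=eos_id,
                                      scorer=BeamSearchScorer(alpha=0, K=1.0),
                                      max_length=3, vocab_size=3)
    sampler.hybridize()  # unlike BeamSearchSampler, the whole search runs as one graph

    inputs = mx.nd.full(shape=(batch_size,), val=bos_id)
    states = decoder.begin_state(batch_size)
    samples, scores, valid_length = sampler(inputs, states)
    # samples: (batch_size, beam_size, length) int32; scores[i, :] in descending order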