diff --git a/fairseq/incremental_decoding_utils.py b/fairseq/incremental_decoding_utils.py
index 8a124a4349..91128e8879 100644
--- a/fairseq/incremental_decoding_utils.py
+++ b/fairseq/incremental_decoding_utils.py
@@ -3,27 +3,48 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from fairseq import utils
+from typing import Dict, Optional
+import uuid
+
+from torch import Tensor
 
 
 class FairseqIncrementalState(object):
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        init_incremental_state(self)
+        self.init_incremental_state()
+
+    def init_incremental_state(self):
+        self._incremental_state_id = str(uuid.uuid4())
+
+    def _get_full_incremental_state_key(self, key: str) -> str:
+        return "{}.{}".format(self._incremental_state_id, key)
+
+    def get_incremental_state(
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+        key: str,
+    ) -> Optional[Dict[str, Optional[Tensor]]]:
+        """Helper for getting incremental state for an nn.Module."""
+        full_key = self._get_full_incremental_state_key(key)
+        if incremental_state is None or full_key not in incremental_state:
+            return None
+        return incremental_state[full_key]
+
+    def set_incremental_state(
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+        key: str,
+        value: Dict[str, Optional[Tensor]],
+    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
+        """Helper for setting incremental state for an nn.Module."""
+        if incremental_state is not None:
+            full_key = self._get_full_incremental_state_key(key)
+            incremental_state[full_key] = value
+        return incremental_state
 
 
 def with_incremental_state(cls):
     cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
     return cls
-
-
-# In most cases we should register incremental states using @with_incremental_state decorator
-# instead of calling into this explicitly in initializer.
-def init_incremental_state(obj):
-    obj.module_name = obj.__class__.__name__
-    utils.INCREMENTAL_STATE_INSTANCE_ID[obj.module_name] = (
-        utils.INCREMENTAL_STATE_INSTANCE_ID.get(obj.module_name, 0) + 1
-    )
-    obj._fairseq_instance_id = utils.INCREMENTAL_STATE_INSTANCE_ID[
-        obj.module_name
-    ]
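Note: the new FairseqIncrementalState keys every lookup by a per-instance uuid4, so two instances of the same class can no longer collide in a shared state dict (the old scheme keyed by class name plus a process-global counter in fairseq.utils). Below is a minimal eager-mode sketch of the resulting behavior; MyModule is a hypothetical class used only for illustration, not part of this patch.

import torch
from torch import nn
from fairseq.incremental_decoding_utils import with_incremental_state


@with_incremental_state
class MyModule(nn.Module):  # hypothetical, for illustration only
    pass


m1, m2 = MyModule(), MyModule()
state = {}
# Both instances write under the logical key "buf", but each write lands on
# its own "<uuid>.buf" entry, so the values stay separate.
state = m1.set_incremental_state(state, "buf", {"x": torch.ones(1)})
state = m2.set_incremental_state(state, "buf", {"x": torch.zeros(1)})
assert m1.get_incremental_state(state, "buf")["x"].item() == 1.0
assert m2.get_incremental_state(state, "buf")["x"].item() == 0.0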
""" seen = set() - - def apply_reorder_incremental_state(module): - if module != self and hasattr(module, 'reorder_incremental_state') \ - and module not in seen: + for module in self.modules(): + if ( + module != self + and hasattr(module, 'reorder_incremental_state') + and module not in seen + ): seen.add(module) - module.reorder_incremental_state(incremental_state, new_order) - - self.apply(apply_reorder_incremental_state) + result = module.reorder_incremental_state(incremental_state, new_order) + if result is not None: + incremental_state = result def set_beam_size(self, beam_size): """Sets the beam size in the decoder and all children.""" diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py index 3e51f09fa6..dd6eaee0a8 100644 --- a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py @@ -11,6 +11,7 @@ import dynamicconv_cuda from fairseq import utils from fairseq.modules.unfold import unfold1d +from fairseq.incremental_decoding_utils import with_incremental_state class dynamicconvFunction(Function): @@ -33,6 +34,7 @@ def backward(ctx, grad_output): return grad_input, grad_weights, None +@with_incremental_state class DynamicconvLayer(nn.Module): def __init__( self, diff --git a/fairseq/modules/lightconv_layer/lightconv_layer.py b/fairseq/modules/lightconv_layer/lightconv_layer.py index 5ff0497449..1f969f41d8 100644 --- a/fairseq/modules/lightconv_layer/lightconv_layer.py +++ b/fairseq/modules/lightconv_layer/lightconv_layer.py @@ -10,6 +10,7 @@ import lightconv_cuda from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state class lightconvFunction(Function): @@ -32,6 +33,7 @@ def backward(ctx, grad_output): return grad_input, grad_weights, None +@with_incremental_state class LightconvLayer(nn.Module): def __init__( self, diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 4df7540f72..d12800c705 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -274,7 +274,7 @@ def forward( saved_state["prev_key_padding_mask"] = key_padding_mask # In this branch incremental_state is never None assert incremental_state is not None - self._set_input_buffer(incremental_state, saved_state) + incremental_state = self._set_input_buffer(incremental_state, saved_state) assert k is not None src_len = k.size(1) @@ -405,28 +405,25 @@ def reorder_incremental_state( for k in input_buffer.keys(): if input_buffer[k] is not None: input_buffer[k] = input_buffer[k].index_select(0, new_order) - self._set_input_buffer(incremental_state, input_buffer) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state def _get_input_buffer( self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] ) -> Dict[str, Optional[Tensor]]: - empty_dict_annotated: Dict[str, Optional[Tensor]] = {} - if incremental_state is None: - return empty_dict_annotated - full_key = utils._get_full_incremental_state_key(self, "attn_state") - if full_key not in incremental_state: - return empty_dict_annotated - return incremental_state[full_key] + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result def _set_input_buffer( self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], buffer: 
diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py
index 4df7540f72..d12800c705 100644
--- a/fairseq/modules/multihead_attention.py
+++ b/fairseq/modules/multihead_attention.py
@@ -274,7 +274,7 @@ def forward(
             saved_state["prev_key_padding_mask"] = key_padding_mask
             # In this branch incremental_state is never None
             assert incremental_state is not None
-            self._set_input_buffer(incremental_state, saved_state)
+            incremental_state = self._set_input_buffer(incremental_state, saved_state)
         assert k is not None
         src_len = k.size(1)
 
@@ -405,28 +405,25 @@ def reorder_incremental_state(
             for k in input_buffer.keys():
                 if input_buffer[k] is not None:
                     input_buffer[k] = input_buffer[k].index_select(0, new_order)
-            self._set_input_buffer(incremental_state, input_buffer)
+            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
+        return incremental_state
 
     def _get_input_buffer(
         self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
     ) -> Dict[str, Optional[Tensor]]:
-        empty_dict_annotated: Dict[str, Optional[Tensor]] = {}
-        if incremental_state is None:
-            return empty_dict_annotated
-        full_key = utils._get_full_incremental_state_key(self, "attn_state")
-        if full_key not in incremental_state:
-            return empty_dict_annotated
-        return incremental_state[full_key]
+        result = self.get_incremental_state(incremental_state, "attn_state")
+        if result is not None:
+            return result
+        else:
+            empty_result: Dict[str, Optional[Tensor]] = {}
+            return empty_result
 
     def _set_input_buffer(
         self,
         incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
         buffer: Dict[str, Optional[Tensor]],
     ):
-        full_key = utils._get_full_incremental_state_key(
-            self, "attn_state"
-        )
-        incremental_state[full_key] = buffer
+        return self.set_incremental_state(incremental_state, "attn_state", buffer)
 
     def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
         return attn_weights
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 70c8f00512..25e8446b11 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -61,27 +61,13 @@ def _move_to_cuda(tensor):
     return apply_to_sample(_move_to_cuda, sample)
 
 
-INCREMENTAL_STATE_INSTANCE_ID = {}
-
-
-def _get_full_incremental_state_key(
-    module_instance: MultiheadAttention, key: str
-) -> str:
-    return "{}.{}.{}".format(
-        module_instance.module_name, module_instance._fairseq_instance_id, key
-    )
-
-
 def get_incremental_state(
     module: MultiheadAttention,
     incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
     key: str,
 ) -> Optional[Dict[str, Optional[Tensor]]]:
     """Helper for getting incremental state for an nn.Module."""
-    full_key = _get_full_incremental_state_key(module, key)
-    if incremental_state is None or full_key not in incremental_state:
-        return None
-    return incremental_state[full_key]
+    return module.get_incremental_state(incremental_state, key)
 
 
 def set_incremental_state(
@@ -89,11 +75,13 @@ def set_incremental_state(
     incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
     key: str,
     value: Dict[str, Optional[Tensor]],
-):
+) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
     """Helper for setting incremental state for an nn.Module."""
     if incremental_state is not None:
-        full_key = _get_full_incremental_state_key(module, key)
-        incremental_state[full_key] = value
+        result = module.set_incremental_state(incremental_state, key, value)
+        if result is not None:
+            incremental_state = result
+    return incremental_state
 
 
 def load_align_dict(replace_unk):
diff --git a/tests/test_export.py b/tests/test_export.py
index 7ae5d57588..e286c7ab49 100644
--- a/tests/test_export.py
+++ b/tests/test_export.py
@@ -7,10 +7,26 @@
 
 
 class TestExportModels(unittest.TestCase):
+
     def test_export_multihead_attention(self):
         module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
         torch.jit.script(module)
 
+    def test_incremental_state_multihead_attention(self):
+        module1 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
+        module1 = torch.jit.script(module1)
+        module2 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
+        module2 = torch.jit.script(module2)
+
+        state = {}
+        state = module1.set_incremental_state(state, 'key', {'a': torch.tensor([1])})
+        state = module2.set_incremental_state(state, 'key', {'a': torch.tensor([2])})
+        v1 = module1.get_incremental_state(state, 'key')['a']
+        v2 = module2.get_incremental_state(state, 'key')['a']
+
+        self.assertEqual(v1, 1)
+        self.assertEqual(v2, 2)
+
     def test_positional_embedding(self):
         module = sinusoidal_positional_embedding.SinusoidalPositionalEmbedding(
             embedding_dim=8, padding_idx=1
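Note: the module-level helpers in fairseq.utils keep their signatures and now simply delegate to the instance methods, so existing call sites continue to work unchanged. A minimal eager-mode sketch (assuming MultiheadAttention carries the incremental-state mixin, as the new test exercises; tensor shapes here are arbitrary illustrations):

import torch
from fairseq import utils
from fairseq.modules import multihead_attention

mha = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
state = {}
# These land in the same per-instance key namespace as
# mha.set_incremental_state / mha.get_incremental_state.
state = utils.set_incremental_state(
    mha, state, "attn_state", {"prev_key": torch.zeros(2, 2, 3, 4)}
)
buf = utils.get_incremental_state(mha, state, "attn_state")
assert buf is not None and "prev_key" in buf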