Commit e0b3254

Myle Ott authored and facebook-github-bot committed
Cleanup new incremental state API (facebookresearch#1005)
Summary:
* Now that we have `FairseqIncrementalState`, we can make `get_incremental_state` and `set_incremental_state` methods on that class instead of keeping them as helper functions in `utils.py`. This should also eventually help with type checking.
* The incremental ID logic was overly complicated; we can simply use `uuid` to generate a unique ID for every instance.
* Add the missing `with_incremental_state` decorator to the light/dynamic conv modules.
* Add an additional unit test: `test_incremental_state_multihead_attention`.

Pull Request resolved: fairinternal/fairseq-py#1005

Test Plan:
* unit tests

Also confirmed this matches master:
```
$ python generate.py ~/data/data-bin/wmt16_en_de_bpe32k --path /checkpoint/myleott/s3/models/wmt16.en-de.joined-dict.transformer/model.pt --beam 4 --lenpen 0.6 --remove-bpe --quiet
(...)
2020-01-22 09:53:38 | INFO | fairseq_cli.generate | Generate test with beam=4: BLEU4 = 29.28, 60.8/35.1/22.8/15.3 (BP=0.997, ratio=0.997, syslen=62859, reflen=63078)
```

Reviewed By: cndn

Differential Revision: D19517908

Pulled By: myleott

fbshipit-source-id: a406490e342d0d30a9231bf823d3350999bda4c0
1 parent 3e0065e commit e0b3254
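
For context, a minimal sketch of how the reworked API is meant to be used. `ToyModule` below is hypothetical (it is not part of this commit); only `with_incremental_state`, `get_incremental_state`, and `set_incremental_state` come from the diff:

```python
import torch
import torch.nn as nn

from fairseq.incremental_decoding_utils import with_incremental_state


@with_incremental_state  # mixes FairseqIncrementalState into the class bases
class ToyModule(nn.Module):
    def forward(self, x, incremental_state):
        # look up this instance's cached tensor (None on the first call)
        cached = self.get_incremental_state(incremental_state, "buf")
        if cached is not None and cached["x"] is not None:
            x = x + cached["x"]
        # store the new value under a key prefixed with this instance's uuid
        self.set_incremental_state(incremental_state, "buf", {"x": x})
        return x


m1, m2 = ToyModule(), ToyModule()
state = {}  # one dict can be shared across modules
m1(torch.ones(1), state)
m2(torch.ones(1), state)
assert len(state) == 2  # the uuid prefix keeps the two "buf" entries separate
```

Because each instance draws its own `uuid` in `init_incremental_state`, the same logical key never collides across instances that share a single state dict.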

File tree

7 files changed: +80 −52 lines

fairseq/incremental_decoding_utils.py

+35-14
```diff
@@ -3,27 +3,48 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from fairseq import utils
+from typing import Dict, Optional
+import uuid
+
+from torch import Tensor
 
 
 class FairseqIncrementalState(object):
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        init_incremental_state(self)
+        self.init_incremental_state()
+
+    def init_incremental_state(self):
+        self._incremental_state_id = str(uuid.uuid4())
+
+    def _get_full_incremental_state_key(self, key: str) -> str:
+        return "{}.{}".format(self._incremental_state_id, key)
+
+    def get_incremental_state(
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+        key: str,
+    ) -> Optional[Dict[str, Optional[Tensor]]]:
+        """Helper for getting incremental state for an nn.Module."""
+        full_key = self._get_full_incremental_state_key(key)
+        if incremental_state is None or full_key not in incremental_state:
+            return None
+        return incremental_state[full_key]
+
+    def set_incremental_state(
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+        key: str,
+        value: Dict[str, Optional[Tensor]],
+    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
+        """Helper for setting incremental state for an nn.Module."""
+        if incremental_state is not None:
+            full_key = self._get_full_incremental_state_key(key)
+            incremental_state[full_key] = value
+        return incremental_state
 
 
 def with_incremental_state(cls):
     cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
     return cls
-
-
-# In most cases we should register incremental states using @with_incremental_state decorator
-# instead of calling into this explicitly in initializer.
-def init_incremental_state(obj):
-    obj.module_name = obj.__class__.__name__
-    utils.INCREMENTAL_STATE_INSTANCE_ID[obj.module_name] = (
-        utils.INCREMENTAL_STATE_INSTANCE_ID.get(obj.module_name, 0) + 1
-    )
-    obj._fairseq_instance_id = utils.INCREMENTAL_STATE_INSTANCE_ID[
-        obj.module_name
-    ]
```
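
A small, hedged illustration of the new key scheme defined above: every lookup key is prefixed with the uuid generated in `init_incremental_state`, which is what replaces the old per-class instance counter.

```python
from fairseq.incremental_decoding_utils import FairseqIncrementalState

a = FairseqIncrementalState()
b = FairseqIncrementalState()

# same logical key, but each instance prepends its own uuid,
# so the full keys are guaranteed to differ
key_a = a._get_full_incremental_state_key("attn_state")
key_b = b._get_full_incremental_state_key("attn_state")
assert key_a != key_b
```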

fairseq/models/fairseq_incremental_decoder.py

+9-7
```diff
@@ -67,14 +67,16 @@ def reorder_incremental_state(self, incremental_state, new_order):
         order changes between time steps based on the selection of beams.
         """
         seen = set()
-
-        def apply_reorder_incremental_state(module):
-            if module != self and hasattr(module, 'reorder_incremental_state') \
-                    and module not in seen:
+        for module in self.modules():
+            if (
+                module != self
+                and hasattr(module, 'reorder_incremental_state')
+                and module not in seen
+            ):
                 seen.add(module)
-                module.reorder_incremental_state(incremental_state, new_order)
-
-        self.apply(apply_reorder_incremental_state)
+                result = module.reorder_incremental_state(incremental_state, new_order)
+                if result is not None:
+                    incremental_state = result
 
     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
```

fairseq/modules/dynamicconv_layer/dynamicconv_layer.py

+2
```diff
@@ -11,6 +11,7 @@
 import dynamicconv_cuda
 from fairseq import utils
 from fairseq.modules.unfold import unfold1d
+from fairseq.incremental_decoding_utils import with_incremental_state
 
 
 class dynamicconvFunction(Function):
@@ -33,6 +34,7 @@ def backward(ctx, grad_output):
         return grad_input, grad_weights, None
 
 
+@with_incremental_state
 class DynamicconvLayer(nn.Module):
     def __init__(
         self,
```

fairseq/modules/lightconv_layer/lightconv_layer.py

+2
```diff
@@ -10,6 +10,7 @@
 
 import lightconv_cuda
 from fairseq import utils
+from fairseq.incremental_decoding_utils import with_incremental_state
 
 
 class lightconvFunction(Function):
@@ -32,6 +33,7 @@ def backward(ctx, grad_output):
         return grad_input, grad_weights, None
 
 
+@with_incremental_state
 class LightconvLayer(nn.Module):
     def __init__(
         self,
```

fairseq/modules/multihead_attention.py

+10-13
```diff
@@ -274,7 +274,7 @@ def forward(
             saved_state["prev_key_padding_mask"] = key_padding_mask
             # In this branch incremental_state is never None
             assert incremental_state is not None
-            self._set_input_buffer(incremental_state, saved_state)
+            incremental_state = self._set_input_buffer(incremental_state, saved_state)
         assert k is not None
         src_len = k.size(1)
 
@@ -405,28 +405,25 @@ def reorder_incremental_state(
             for k in input_buffer.keys():
                 if input_buffer[k] is not None:
                     input_buffer[k] = input_buffer[k].index_select(0, new_order)
-            self._set_input_buffer(incremental_state, input_buffer)
+            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
+        return incremental_state
 
     def _get_input_buffer(
         self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
     ) -> Dict[str, Optional[Tensor]]:
-        empty_dict_annotated: Dict[str, Optional[Tensor]] = {}
-        if incremental_state is None:
-            return empty_dict_annotated
-        full_key = utils._get_full_incremental_state_key(self, "attn_state")
-        if full_key not in incremental_state:
-            return empty_dict_annotated
-        return incremental_state[full_key]
+        result = self.get_incremental_state(incremental_state, "attn_state")
+        if result is not None:
+            return result
+        else:
+            empty_result: Dict[str, Optional[Tensor]] = {}
+            return empty_result
 
     def _set_input_buffer(
         self,
         incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
         buffer: Dict[str, Optional[Tensor]],
     ):
-        full_key = utils._get_full_incremental_state_key(
-            self, "attn_state"
-        )
-        incremental_state[full_key] = buffer
+        return self.set_incremental_state(incremental_state, "attn_state", buffer)
 
     def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
         return attn_weights
```

fairseq/utils.py

+6-18
```diff
@@ -61,39 +61,27 @@ def _move_to_cuda(tensor):
     return apply_to_sample(_move_to_cuda, sample)
 
 
-INCREMENTAL_STATE_INSTANCE_ID = {}
-
-
-def _get_full_incremental_state_key(
-    module_instance: MultiheadAttention, key: str
-) -> str:
-    return "{}.{}.{}".format(
-        module_instance.module_name, module_instance._fairseq_instance_id, key
-    )
-
-
 def get_incremental_state(
     module: MultiheadAttention,
     incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
     key: str,
 ) -> Optional[Dict[str, Optional[Tensor]]]:
     """Helper for getting incremental state for an nn.Module."""
-    full_key = _get_full_incremental_state_key(module, key)
-    if incremental_state is None or full_key not in incremental_state:
-        return None
-    return incremental_state[full_key]
+    return module.get_incremental_state(incremental_state, key)
 
 
 def set_incremental_state(
     module: MultiheadAttention,
     incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
     key: str,
     value: Dict[str, Optional[Tensor]],
-):
+) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
     """Helper for setting incremental state for an nn.Module."""
     if incremental_state is not None:
-        full_key = _get_full_incremental_state_key(module, key)
-        incremental_state[full_key] = value
+        result = module.set_incremental_state(incremental_state, key, value)
+        if result is not None:
+            incremental_state = result
+    return incremental_state
 
 
 def load_align_dict(replace_unk):
```
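
A small sketch of the backward-compatible path: call sites that go through `fairseq.utils` keep working, since the module-level helpers now just delegate to the per-instance methods (the key name and cached value below are only illustrative).

```python
import torch

from fairseq import utils
from fairseq.modules.multihead_attention import MultiheadAttention

mha = MultiheadAttention(embed_dim=8, num_heads=2)
state = {}

# delegates to mha.set_incremental_state(...) and returns the updated dict
state = utils.set_incremental_state(mha, state, "attn_state", {"prev_key": torch.zeros(1)})

# delegates to mha.get_incremental_state(...)
buf = utils.get_incremental_state(mha, state, "attn_state")
assert buf is not None and torch.equal(buf["prev_key"], torch.zeros(1))
```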

tests/test_export.py

+16
```diff
@@ -7,10 +7,26 @@
 
 
 class TestExportModels(unittest.TestCase):
+
     def test_export_multihead_attention(self):
         module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
         torch.jit.script(module)
 
+    def test_incremental_state_multihead_attention(self):
+        module1 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
+        module1 = torch.jit.script(module1)
+        module2 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
+        module2 = torch.jit.script(module2)
+
+        state = {}
+        state = module1.set_incremental_state(state, 'key', {'a': torch.tensor([1])})
+        state = module2.set_incremental_state(state, 'key', {'a': torch.tensor([2])})
+        v1 = module1.get_incremental_state(state, 'key')['a']
+        v2 = module2.get_incremental_state(state, 'key')['a']
+
+        self.assertEqual(v1, 1)
+        self.assertEqual(v2, 2)
+
     def test_positional_embedding(self):
         module = sinusoidal_positional_embedding.SinusoidalPositionalEmbedding(
             embedding_dim=8, padding_idx=1
```
