|
3 | 3 | # This source code is licensed under the MIT license found in the
|
4 | 4 | # LICENSE file in the root directory of this source tree.
|
5 | 5 |
|
6 |
| -from fairseq import utils |
| 6 | +from typing import Dict, Optional |
| 7 | +import uuid |
| 8 | + |
| 9 | +from torch import Tensor |
7 | 10 |
|
8 | 11 |
|
9 | 12 | class FairseqIncrementalState(object):
|
| 13 | + |
10 | 14 | def __init__(self, *args, **kwargs):
|
11 | 15 | super().__init__(*args, **kwargs)
|
12 |
| - init_incremental_state(self) |
| 16 | + self.init_incremental_state() |
| 17 | + |
| 18 | + def init_incremental_state(self): |
| 19 | + self._incremental_state_id = str(uuid.uuid4()) |
| 20 | + |
| 21 | + def _get_full_incremental_state_key(self, key: str) -> str: |
| 22 | + return "{}.{}".format(self._incremental_state_id, key) |
| 23 | + |
| 24 | + def get_incremental_state( |
| 25 | + self, |
| 26 | + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], |
| 27 | + key: str, |
| 28 | + ) -> Optional[Dict[str, Optional[Tensor]]]: |
| 29 | + """Helper for getting incremental state for an nn.Module.""" |
| 30 | + full_key = self._get_full_incremental_state_key(key) |
| 31 | + if incremental_state is None or full_key not in incremental_state: |
| 32 | + return None |
| 33 | + return incremental_state[full_key] |
| 34 | + |
| 35 | + def set_incremental_state( |
| 36 | + self, |
| 37 | + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], |
| 38 | + key: str, |
| 39 | + value: Dict[str, Optional[Tensor]], |
| 40 | + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: |
| 41 | + """Helper for setting incremental state for an nn.Module.""" |
| 42 | + if incremental_state is not None: |
| 43 | + full_key = self._get_full_incremental_state_key(key) |
| 44 | + incremental_state[full_key] = value |
| 45 | + return incremental_state |
13 | 46 |
|
14 | 47 |
|
15 | 48 | def with_incremental_state(cls):
|
16 | 49 | cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
|
17 | 50 | return cls
|
18 |
| - |
19 |
| - |
20 |
| -# In most cases we should register incremental states using @with_incremental_state decorator |
21 |
| -# instead of calling into this explicitly in initializer. |
22 |
| -def init_incremental_state(obj): |
23 |
| - obj.module_name = obj.__class__.__name__ |
24 |
| - utils.INCREMENTAL_STATE_INSTANCE_ID[obj.module_name] = ( |
25 |
| - utils.INCREMENTAL_STATE_INSTANCE_ID.get(obj.module_name, 0) + 1 |
26 |
| - ) |
27 |
| - obj._fairseq_instance_id = utils.INCREMENTAL_STATE_INSTANCE_ID[ |
28 |
| - obj.module_name |
29 |
| - ] |
|
0 commit comments