
Commit 625ebe3

Myle Ott authored and facebook-github-bot committed
Standardize on 'teacher forcing' rather than 'input feeding' which is… (facebookresearch#769)
Summary: Input feeding generally refers to a slightly different concept

Pull Request resolved: fairinternal/fairseq-py#769

Differential Revision: D16491898

Pulled By: myleott

fbshipit-source-id: 68573584e820f11f199db4e7e37e9ee7a69a3287
1 parent 0b6fbec commit 625ebe3

8 files changed, +20 -21 lines changed

docs/tutorial_classifying_names.rst

+1 -1
@@ -285,7 +285,7 @@ following contents::
 max_source_positions=self.args.max_positions,
 max_target_positions=1,
 # Since our target is a single class label, there's no need for
-# input feeding. If we set this to ``True`` then our Model's
+# teacher forcing. If we set this to ``True`` then our Model's
 # ``forward()`` method would receive an additional argument called
 # *prev_output_tokens* that would contain a shifted version of the
 # target sequence.

docs/tutorial_simple_lstm.rst

+6 -6
@@ -125,9 +125,9 @@ Decoder

 Our Decoder will predict the next word, conditioned on the Encoder's final
 hidden state and an embedded representation of the previous target word -- which
-is sometimes called *input feeding* or *teacher forcing*. More specifically,
-we'll use a :class:`torch.nn.LSTM` to produce a sequence of hidden states that
-we'll project to the size of the output vocabulary to predict each target word.
+is sometimes called *teacher forcing*. More specifically, we'll use a
+:class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project
+to the size of the output vocabulary to predict each target word.

 ::

@@ -171,7 +171,7 @@ we'll project to the size of the output vocabulary to predict each target word.
 """
 Args:
     prev_output_tokens (LongTensor): previous decoder outputs of shape
-        `(batch, tgt_len)`, for input feeding/teacher forcing
+        `(batch, tgt_len)`, for teacher forcing
     encoder_out (Tensor, optional): output from the encoder, used for
         encoder-side attention

@@ -387,8 +387,8 @@ previous hidden states.

 In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
 special mode at inference time where the Model only receives a single timestep
-of input corresponding to the immediately previous output token (for input
-feeding) and must produce the next output incrementally. Thus the model must
+of input corresponding to the immediately previous output token (for teacher
+forcing) and must produce the next output incrementally. Thus the model must
 cache any long-term state that is needed about the sequence, e.g., hidden
 states, convolutional states, etc.
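The decoder described in this tutorial hunk is essentially an LSTM over embedded previous target tokens, with each hidden state projected to vocabulary logits. Below is a minimal self-contained sketch of that idea; it is not the tutorial's actual SimpleLSTMDecoder, and the class name, sizes, and signature are invented for illustration:

    import torch
    import torch.nn as nn

    class ToyLSTMDecoder(nn.Module):
        """Embeds prev_output_tokens, runs an LSTM, and projects each hidden
        state to vocabulary logits -- the teacher-forcing setup described above."""

        def __init__(self, vocab_size=1000, embed_dim=32, hidden_dim=64):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, embed_dim)
            self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
            self.output_proj = nn.Linear(hidden_dim, vocab_size)

        def forward(self, prev_output_tokens, encoder_hidden=None):
            # prev_output_tokens: (batch, tgt_len), the gold target shifted
            # right by one position (teacher forcing).
            x = self.embed(prev_output_tokens)
            # Optionally condition on the encoder's final hidden state.
            state = None
            if encoder_hidden is not None:
                state = (encoder_hidden, torch.zeros_like(encoder_hidden))
            hidden_states, _ = self.lstm(x, state)
            return self.output_proj(hidden_states)  # (batch, tgt_len, vocab_size)

    decoder = ToyLSTMDecoder()
    prev_output_tokens = torch.randint(0, 1000, (2, 5))  # batch=2, tgt_len=5
    logits = decoder(prev_output_tokens)
    print(logits.shape)  # torch.Size([2, 5, 1000])

Because the decoder receives the shifted gold sequence at training time, all target positions can be predicted in parallel under teacher forcing.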

fairseq/data/language_pair_dataset.py

+5 -6
@@ -88,8 +88,7 @@ class LanguagePairDataset(FairseqDataset):
     shuffle (bool, optional): shuffle dataset elements before batching
         (default: True).
     input_feeding (bool, optional): create a shifted version of the targets
-        to be passed into the model for input feeding/teacher forcing
-        (default: True).
+        to be passed into the model for teacher forcing (default: True).
     remove_eos_from_source (bool, optional): if set, removes eos from end
         of source if it's present (default: False).
     append_eos_to_target (bool, optional): if set, appends eos to end of
@@ -167,10 +166,10 @@ def collater(self, samples):
     - `src_lengths` (LongTensor): 1D Tensor of the unpadded
       lengths of each source sentence of shape `(bsz)`
     - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
-      tokens in the target sentence, shifted right by one position
-      for input feeding/teacher forcing, of shape `(bsz,
-      tgt_len)`. This key will not be present if *input_feeding*
-      is ``False``. Padding will appear on the left if
+      tokens in the target sentence, shifted right by one
+      position for teacher forcing, of shape `(bsz, tgt_len)`.
+      This key will not be present if *input_feeding* is
+      ``False``. Padding will appear on the left if
       *left_pad_target* is ``True``.

     - `target` (LongTensor): a padded 2D Tensor of tokens in the
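The `prev_output_tokens` batch described here is simply the target rotated right by one position, so the decoder sees the gold previous token at every step. A rough standalone illustration of the shift (not the collater's actual implementation; the token ids are made up, with 2 standing in for EOS):

    import torch

    eos = 2  # placeholder EOS id, for illustration only
    target = torch.tensor([4317, 58, 907, eos])

    # Shift right by one so that at step t the decoder is fed the gold
    # token from step t-1 (teacher forcing); EOS ends up at the front.
    prev_output_tokens = torch.cat([target[-1:], target[:-1]])

    print(target)              # [4317, 58, 907, 2]
    print(prev_output_tokens)  # [2, 4317, 58, 907]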

fairseq/models/fairseq_decoder.py

+1 -1
@@ -22,7 +22,7 @@ def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
 """
 Args:
     prev_output_tokens (LongTensor): shifted output tokens of shape
-        `(batch, tgt_len)`, for input feeding/teacher forcing
+        `(batch, tgt_len)`, for teacher forcing
     encoder_out (dict, optional): output from the encoder, used for
         encoder-side attention

fairseq/models/fairseq_incremental_decoder.py

+2 -2
@@ -13,7 +13,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):

 Incremental decoding is a special mode at inference time where the Model
 only receives a single timestep of input corresponding to the previous
-output token (for input feeding) and must produce the next output
+output token (for teacher forcing) and must produce the next output
 *incrementally*. Thus the model must cache any long-term state that is
 needed about the sequence, e.g., hidden states, convolutional states, etc.

@@ -37,7 +37,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None,
 """
 Args:
     prev_output_tokens (LongTensor): shifted output tokens of shape
-        `(batch, tgt_len)`, for input feeding/teacher forcing
+        `(batch, tgt_len)`, for teacher forcing
     encoder_out (dict, optional): output from the encoder, used for
         encoder-side attention
     incremental_state (dict, optional): dictionary used for storing
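As this docstring explains, an incremental decoder is called once per output token at inference time and must cache its own recurrent state between calls. Below is a rough self-contained sketch of that pattern using a plain LSTM and an ordinary dict for incremental_state; it is not FairseqIncrementalDecoder's actual state API, and every name in it is illustrative:

    import torch
    import torch.nn as nn

    class ToyIncrementalDecoder(nn.Module):
        """Generates one token per call, caching the LSTM hidden/cell state in
        incremental_state so only the newest token has to be fed in."""

        def __init__(self, vocab_size=1000, embed_dim=32, hidden_dim=64):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, embed_dim)
            self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
            self.output_proj = nn.Linear(hidden_dim, vocab_size)

        def forward(self, prev_output_tokens, incremental_state=None):
            if incremental_state is not None:
                # Inference: keep only the immediately previous output token and
                # restore the cached recurrent state from earlier calls.
                prev_output_tokens = prev_output_tokens[:, -1:]
                state = incremental_state.get('lstm_state')
            else:
                state = None  # training: full teacher-forced sequence, no cache
            x = self.embed(prev_output_tokens)
            out, new_state = self.lstm(x, state)
            if incremental_state is not None:
                incremental_state['lstm_state'] = new_state  # cache for next step
            return self.output_proj(out)  # (batch, steps, vocab_size)

    decoder = ToyIncrementalDecoder()
    incremental_state = {}
    tokens = torch.tensor([[2]])  # start from a placeholder EOS/BOS id
    for _ in range(3):
        logits = decoder(tokens, incremental_state)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    print(tokens)  # the start token followed by 3 greedily decoded tokens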

fairseq/models/fairseq_model.py

+3 -3
@@ -202,8 +202,8 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
 Run the forward pass for an encoder-decoder model.

 First feed a batch of source tokens through the encoder. Then, feed the
-encoder output and previous decoder outputs (i.e., input feeding/teacher
-forcing) to the decoder to produce the next outputs::
+encoder output and previous decoder outputs (i.e., teacher forcing) to
+the decoder to produce the next outputs::

     encoder_out = self.encoder(src_tokens, src_lengths)
     return self.decoder(prev_output_tokens, encoder_out)
@@ -213,7 +213,7 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
         `(batch, src_len)`
     src_lengths (LongTensor): source sentence lengths of shape `(batch)`
     prev_output_tokens (LongTensor): previous decoder outputs of shape
-        `(batch, tgt_len)`, for input feeding/teacher forcing
+        `(batch, tgt_len)`, for teacher forcing

 Returns:
     tuple:

fairseq/models/lightconv.py

+1 -1
@@ -345,7 +345,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
 """
 Args:
     prev_output_tokens (LongTensor): previous decoder outputs of shape
-        `(batch, tgt_len)`, for input feeding/teacher forcing
+        `(batch, tgt_len)`, for teacher forcing
     encoder_out (Tensor, optional): output from the encoder, used for
         encoder-side attention
     incremental_state (dict): dictionary used for storing state during

fairseq/models/transformer.py

+1 -1
@@ -370,7 +370,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None,
 """
 Args:
     prev_output_tokens (LongTensor): previous decoder outputs of shape
-        `(batch, tgt_len)`, for input feeding/teacher forcing
+        `(batch, tgt_len)`, for teacher forcing
     encoder_out (Tensor, optional): output from the encoder, used for
         encoder-side attention
     incremental_state (dict): dictionary used for storing state during
