
Commit

merge Fix skip generation #7270
Signed-off-by: Evelina <ebakhturina@nvidia.com>
ekmb committed Aug 21, 2023
1 parent ae3f7d2 commit f54caae
Showing 1 changed file with 18 additions and 16 deletions.
@@ -387,22 +387,24 @@ def inference_step(self, dataloader_iter, batch_idx, mode, dataloader_idx=0):
         metadata = batch.pop('metadata')
         loss = super().validation_step(itertools.chain([batch]), batch_idx)
 
-        # We need _inference_config to get generation params
-        # add_BOS and tokens_to_generate are set in dataset
-        if self.get_inference_config() is None:
-            self.set_inference_config(inference_config={})
-        self._inference_config['add_BOS'] = data_cfg.add_bos
-        self._inference_config['tokens_to_generate'] = data_cfg.get('tokens_to_generate')
-
-        output = self.predict_step(batch, batch_idx, dataloader_idx)
-
-        inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']]
-        labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']]
-        preds_text = [
-            self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')])
-            for t, l in zip(output['token_ids'], batch['context_lengths'])
-        ]
-
+        if data_cfg.get("write_predictions_to_file", False) or data_cfg.metric.name != 'loss':
+            # We need _inference_config to get generation params
+            # add_BOS and tokens_to_generate are set in dataset
+            if self.get_inference_config() is None:
+                self.set_inference_config(inference_config={})
+            self._inference_config['add_BOS'] = data_cfg.add_bos
+            self._inference_config['tokens_to_generate'] = data_cfg.get('tokens_to_generate')
+
+            output = self.predict_step(batch, batch_idx, dataloader_idx)
+
+            inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']]
+            labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']]
+            preds_text = [
+                self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')])
+                for t, l in zip(output['token_ids'], batch['context_lengths'])
+            ]
+        else:
+            inputs_text, labels_text, preds_text = [], [], []
         return {
             'loss': loss,
             'preds': preds_text, # [str]

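For context, the change wraps the generation path in a guard so the expensive predict_step call only runs when its text outputs are actually consumed. Below is a minimal, self-contained sketch of that guard; DataCfg, Metric, and fake_generate are hypothetical stand-ins for the NeMo config and model objects above, not part of the actual codebase.

from dataclasses import dataclass, field


@dataclass
class Metric:
    # Stand-in for data_cfg.metric; 'loss' means no text metric is computed.
    name: str = "loss"


@dataclass
class DataCfg:
    # Stand-in for the OmegaConf data config used in inference_step.
    metric: Metric = field(default_factory=Metric)
    extra: dict = field(default_factory=dict)

    def get(self, key, default=None):
        return self.extra.get(key, default)


def fake_generate(batch):
    # Stand-in for the expensive self.predict_step(...) generation call.
    return ["generated text"] * len(batch)


def inference_texts(batch, data_cfg):
    # Run generation only when its outputs are consumed: either predictions
    # are written to a file, or the metric is computed on generated text
    # rather than on the loss alone. Otherwise return empty lists, matching
    # the else branch added by this commit.
    if data_cfg.get("write_predictions_to_file", False) or data_cfg.metric.name != "loss":
        preds_text = fake_generate(batch)
    else:
        preds_text = []
    return preds_text


# With the default loss-only metric, generation is skipped entirely:
print(inference_texts([1, 2, 3], DataCfg()))  # -> []
# With a text-based metric (e.g. 'rouge'), generation runs:
print(inference_texts([1, 2, 3], DataCfg(metric=Metric(name="rouge"))))  # -> ['generated text', ...]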