diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 6d40dfafe8e477..befeaccd55f6ee 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -226,7 +226,7 @@ class EntitySpanClassificationOutput(ModelOutput): Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Classification loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + logits (`torch.FloatTensor` of shape `(batch_size, entity_length, config.num_labels)`): Classification scores (before SoftMax). hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of