Avoid needing repad_logits_with_grad, always repad with grads when training

I'm not 100% sure that the conditional with "or labels is None" makes sense though - not sure what the intention is there. Perhaps we can remove that?
tomaarsen committed Jan 9, 2025
1 parent ab11657 commit cedcb4e
Showing 3 changed files with 2 additions and 12 deletions.
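The substantive change is the condition that decides whether logits keep their gradient while being re-padded after Flash Attention 2's unpadded forward pass: instead of the `repad_logits_with_grad` config flag, the model's training mode is used. A minimal sketch of the old and new gating (the `_repad_context_*` helpers below are hypothetical, written only to illustrate the diff):

```python
from contextlib import nullcontext

import torch


def _repad_context_old(config, labels):
    # Before this commit: gradients were kept only when the user opted in via the
    # `repad_logits_with_grad` config flag, or when no labels were passed.
    return nullcontext() if config.repad_logits_with_grad or labels is None else torch.no_grad()


def _repad_context_new(module: torch.nn.Module, labels):
    # After this commit: gradients are kept whenever the module is in training mode
    # (module.training, toggled via .train()/.eval()) or when no labels were passed,
    # so the extra config flag is no longer needed.
    return nullcontext() if module.training or labels is None else torch.no_grad()
```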
5 changes: 0 additions & 5 deletions src/transformers/models/modernbert/configuration_modernbert.py
@@ -109,9 +109,6 @@ class ModernBertConfig(PretrainedConfig):
             the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
             shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
             be faster in some scenarios.
-        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
-            When True, ModernBertForMaskedLM keep track of the logits' gradient when repadding for output. This only
-            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.

     Examples:
@@ -167,7 +164,6 @@ def __init__(
         sparse_prediction=False,
         sparse_pred_ignore_index=-100,
         reference_compile=None,
-        repad_logits_with_grad=False,
         **kwargs,
     ):
         super().__init__(
@@ -207,7 +203,6 @@ def __init__(
         self.sparse_prediction = sparse_prediction
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile
-        self.repad_logits_with_grad = repad_logits_with_grad

         if self.classifier_pooling not in ["cls", "mean"]:
             raise ValueError(
2 changes: 1 addition & 1 deletion src/transformers/models/modernbert/modeling_modernbert.py
@@ -1104,7 +1104,7 @@ def forward(
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)

         if self.config._attn_implementation == "flash_attention_2":
-            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+            with nullcontext() if self.training or labels is None else torch.no_grad():
                 logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

         if not return_dict:
7 changes: 1 addition & 6 deletions src/transformers/models/modernbert/modular_modernbert.py
@@ -142,9 +142,6 @@ class ModernBertConfig(PretrainedConfig):
             the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
             shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
             be faster in some scenarios.
-        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
-            When True, ModernBertForMaskedLM keep track of the logits' gradient when repadding for output. This only
-            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.

     Examples:
@@ -200,7 +197,6 @@ def __init__(
         sparse_prediction=False,
         sparse_pred_ignore_index=-100,
         reference_compile=None,
-        repad_logits_with_grad=False,
         **kwargs,
     ):
         super().__init__(
@@ -240,7 +236,6 @@ def __init__(
         self.sparse_prediction = sparse_prediction
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile
-        self.repad_logits_with_grad = repad_logits_with_grad

         if self.classifier_pooling not in ["cls", "mean"]:
             raise ValueError(
@@ -1262,7 +1257,7 @@ def forward(
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)

         if self.config._attn_implementation == "flash_attention_2":
-            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+            with nullcontext() if self.training or labels is None else torch.no_grad():
                 logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

         if not return_dict:
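In practice, gradient tracking during re-padding now follows `model.train()` / `model.eval()` rather than a config flag. A rough usage sketch, assuming the transformers library with the `answerdotai/ModernBERT-base` checkpoint, flash-attn installed, and a CUDA device (illustrative, not taken from the commit):

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_id = "answerdotai/ModernBERT-base"  # assumed checkpoint, for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
).to("cuda")

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt").to("cuda")
labels = inputs["input_ids"].clone()

model.train()  # re-padded logits keep gradients (previously required repad_logits_with_grad=True)
out = model(**inputs, labels=labels)

model.eval()   # with labels passed, re-padding runs under torch.no_grad()
with torch.no_grad():
    out = model(**inputs, labels=labels)
```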
