Skip to content

Commit

Permalink
Fix incomplete outputs of FlaxBert (#18772)
Browse files Browse the repository at this point in the history
* Fix incomplete FlaxBert outputs

* fix big_bird electra roberta
  • Loading branch information
duongna21 authored Aug 26, 2022
1 parent 62ceb4d commit 21f6f58
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/bert/modeling_flax_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ def __call__(
if output_hidden_states:
all_hidden_states += (hidden_states,)

outputs = (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

if not return_dict:
return tuple(v for v in outputs if v is not None)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/big_bird/modeling_flax_big_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@ def __call__(
if output_hidden_states:
all_hidden_states += (hidden_states,)

outputs = (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

if not return_dict:
return tuple(v for v in outputs if v is not None)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/electra/modeling_flax_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def __call__(
if output_hidden_states:
all_hidden_states += (hidden_states,)

outputs = (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

if not return_dict:
return tuple(v for v in outputs if v is not None)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/roberta/modeling_flax_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def __call__(
if output_hidden_states:
all_hidden_states += (hidden_states,)

outputs = (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

if not return_dict:
return tuple(v for v in outputs if v is not None)
Expand Down

0 comments on commit 21f6f58

Please sign in to comment.