[RWKV] Add Gradient Checkpointing support for RWKV (#24955)
add GC support for RWKV
younesbelkada authored Jul 20, 2023
1 parent 9f912ef · commit 89a1f34
Showing 1 changed file with 27 additions and 3 deletions.
src/transformers/models/rwkv/modeling_rwkv.py
@@ -406,6 +406,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
     base_model_prefix = "rwkv"
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -605,6 +606,8 @@ def __init__(self, config):

         self.layers_are_rescaled = False

+        self.gradient_checkpointing = False
+
         # Initialize weights and apply final processing
         self.post_init()

@@ -659,14 +662,35 @@ def forward(
             ]
             state[4] -= 1e30

+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         hidden_states = inputs_embeds

         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for idx, block in enumerate(self.blocks):
-            hidden_states, state, attentions = block(
-                hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+                    return custom_forward
+
+                hidden_states, state, attentions = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block), hidden_states, state
+                )
+            else:
+                hidden_states, state, attentions = block(
+                    hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
+                )
+
             if (
                 self.layers_are_rescaled
                 and self.config.rescale_every > 0
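
Usage note (not part of the diff): the new code path is driven by the standard PreTrainedModel gradient-checkpointing API rather than a model-specific argument. A minimal sketch, assuming a current transformers install and using the public RWKV/rwkv-4-169m-pile checkpoint purely for illustration:

# Minimal usage sketch (not part of this commit); the checkpoint name is an
# illustrative assumption.
from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

# Turn on gradient checkpointing through the PreTrainedModel API; during
# training, the forward pass then takes the torch.utils.checkpoint branch above.
model.gradient_checkpointing_enable()
model.train()

inputs = tokenizer("Gradient checkpointing trades compute for memory.", return_tensors="pt")
loss = model(**inputs, labels=inputs["input_ids"]).loss
loss.backward()  # block activations are recomputed during backward instead of stored

Enabling checkpointing trades extra forward compute in the backward pass for lower activation memory; as the warning added in the diff notes, use_cache is forced off while checkpointing is active during training.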
