Merge pull request #719 from allenai/shanea/hf-olmo-gradient-checkpointing

[HF OLMo] Add flash attention and gradient checkpointing support
2015aroras authored Sep 11, 2024
2 parents 0b92077 + 90ef327 commit 47f8f5a
Showing 5 changed files with 142 additions and 14 deletions.
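
For orientation, here is a minimal usage sketch of both features from the Hugging Face side, mirroring the new tests in tests/hf_olmo/modeling_olmo_test.py below. It is not part of the diff: the checkpoint path is a placeholder (the repo's tiny test fixture), and the flash attention variant additionally assumes a CUDA device with the flash-attn package installed.

```python
# Minimal sketch, not part of the diff: exercising the new support the same way the added tests do.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import hf_olmo  # noqa: F401  # registers the OLMo config/model with the HF Auto classes

model_path = "test_fixtures/test-olmo-model"  # placeholder; any converted OLMo checkpoint works

# Gradient checkpointing: the standard HF toggle now maps onto OLMo's
# whole-layer activation checkpointing strategy.
model = AutoModelForCausalLM.from_pretrained(model_path)
model.gradient_checkpointing_enable()

tokenizer = AutoTokenizer.from_pretrained(model_path)
input_ids = torch.tensor(tokenizer.encode("My name is OLMo!")).unsqueeze(0)
logits = model(input_ids).logits

# Flash attention: requested at load time via attn_implementation
# (needs a CUDA device and the flash-attn package).
model_fa = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2")
```
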
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added ability to try loading latest checkpoint from save folder using `--try_load_latest_save`.
- Added support for flash attention and gradient checkpointing to `hf_olmo`.

## [v0.5.0](https://github.com/allenai/OLMo/releases/tag/v0.5.0) - 2024-08-26

36 changes: 34 additions & 2 deletions hf_olmo/modeling_olmo.py
@@ -1,14 +1,14 @@
import logging
from dataclasses import fields
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from olmo.config import ModelConfig
from olmo.config import ActivationCheckpointingStrategy, ModelConfig
from olmo.model import OLMo

from .configuration_olmo import OLMoConfig
@@ -26,6 +26,15 @@ def create_model_config_from_pretrained_config(config: OLMoConfig):
kwargs[field.name] = getattr(config, field.name)

model_config = ModelConfig(**kwargs)

# Handle flash attention settings
if config._attn_implementation == "flash_attention_2":
model_config.flash_attention = True
elif config._attn_implementation in ("eager", "sdpa"):
model_config.flash_attention = False
else:
raise ValueError(f"Unexpected _attn_implementation {config._attn_implementation}")

return model_config


@@ -37,10 +46,16 @@ class OLMoForCausalLM(PreTrainedModel):
config_class = OLMoConfig
base_model_prefix = "model"
_no_split_modules = ["OLMoBlock"]
_supports_flash_attn_2 = True
_supports_sdpa = True
supports_gradient_checkpointing = True

def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
super().__init__(config)

self._gradient_checkpointing_func: Optional[Callable] = None
self._gradient_checkpointing = False

if not model:
model_config = create_model_config_from_pretrained_config(config)
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
@@ -49,6 +64,23 @@ def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params
else:
self.model = model

@property
def gradient_checkpointing(self) -> bool:
return self._gradient_checkpointing

@gradient_checkpointing.setter
def gradient_checkpointing(self, enabled: bool):
if self._gradient_checkpointing == enabled:
return

# HF does not specify a way to pass checkpointing strategies, so we pick
# whole layer as our strategy. We can make this configurable later if needed.
checkpointing_strategy = ActivationCheckpointingStrategy.whole_layer if enabled else None
self.model.set_activation_checkpointing(
checkpointing_strategy, checkpoint_func=self._gradient_checkpointing_func
)
self._gradient_checkpointing = enabled

def forward(
self,
input_ids: torch.LongTensor = None,
30 changes: 20 additions & 10 deletions olmo/model.py
@@ -418,7 +418,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
self.__cache = cache
assert config.d_model % config.n_heads == 0

self._activation_checkpoint_fn = None
self._activation_checkpoint_fn: Optional[Callable] = None

# Dropout.
self.dropout = Dropout(config.residual_dropout)
@@ -500,9 +500,11 @@ def reset_parameters(self):
init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)

def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
def set_activation_checkpointing(
self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
):
if strategy == ActivationCheckpointingStrategy.fine_grained:
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
self._activation_checkpoint_fn = checkpoint_func or activation_checkpoint_function(self.config)
else:
self._activation_checkpoint_fn = None

@@ -980,7 +982,7 @@ class OLMoOutput(NamedTuple):
Attention keys and values from each block.
"""

hidden_states: Optional[Tuple[torch.Tensor]]
hidden_states: Optional[Tuple[torch.Tensor, ...]]
"""
Hidden states from each block.
"""
@@ -1050,10 +1052,12 @@ def reset_parameters(self):
for block in self:
block.reset_parameters()

def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
def set_activation_checkpointing(
self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
):
self.activation_checkpointing_strategy = strategy
for block in self:
block.set_activation_checkpointing(strategy)
block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)


class OLMo(nn.Module):
@@ -1140,14 +1144,16 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))

def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
def set_activation_checkpointing(
self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
):
self.activation_checkpointing_strategy = strategy
if self.config.block_group_size != 1:
for block_group in self.transformer.block_groups:
block_group.set_activation_checkpointing(strategy)
block_group.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
else:
for block in self.transformer.blocks:
block.set_activation_checkpointing(strategy)
block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)

@property
def device(self) -> torch.device:
@@ -1445,7 +1451,11 @@ def forward(
if self.config.scale_logits:
logits.mul_(1 / math.sqrt(self.config.d_model))

return OLMoOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
return OLMoOutput(
logits=logits,
attn_key_values=attn_key_values,
hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
)

def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
if wrap_strategy is None:
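
At the OLMo level, the change above means set_activation_checkpointing can now be handed the checkpointing callable to use, which is how the HF wrapper forwards the function supplied by gradient_checkpointing_enable(). Below is a minimal sketch of calling it directly, not part of the diff: the fixture path is a placeholder, the fine_grained strategy is illustrative (the HF wrapper itself uses whole_layer), and the pre-bound torch.utils.checkpoint.checkpoint partial is an assumption modeled on what transformers passes in.

```python
# Minimal sketch, not part of the diff: driving the extended low-level API directly.
from functools import partial

import torch
from transformers import AutoModelForCausalLM

import hf_olmo  # noqa: F401  # registers the OLMo architecture with the HF Auto classes
from olmo.config import ActivationCheckpointingStrategy

olmo_model = AutoModelForCausalLM.from_pretrained("test_fixtures/test-olmo-model").model  # olmo.model.OLMo

# checkpoint_func is optional; with the fine_grained strategy and no function given,
# OLMo falls back to its own activation_checkpoint_function(config), as before this change.
olmo_model.set_activation_checkpointing(
    ActivationCheckpointingStrategy.fine_grained,
    checkpoint_func=partial(torch.utils.checkpoint.checkpoint, use_reentrant=False),
)
```
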
8 changes: 7 additions & 1 deletion test_fixtures/test-olmo-model/config.json
@@ -13,7 +13,9 @@
"block_type": "sequential",
"clip_qkv": null,
"d_model": 32,
"emb_init_std": null,
"embedding_dropout": 0.1,
"embedding_layer_norm": false,
"embedding_size": 50304,
"eos_token_id": 50256,
"flash_attention": false,
@@ -22,6 +24,7 @@
"init_device": null,
"init_fn": "normal",
"init_std": 0.02,
"layer_norm_eps": 1e-05,
"layer_norm_type": "default",
"layer_norm_with_affine": true,
"max_sequence_length": 1024,
@@ -32,13 +35,16 @@
"n_heads": 1,
"n_kv_heads": null,
"n_layers": 1,
"norm_after": false,
"pad_token_id": 50256,
"precision": null,
"residual_dropout": 0.1,
"rope": false,
"rope_full_precision": true,
"rope_theta": 10000,
"scale_emb_init": false,
"scale_logits": false,
"transformers_version": "4.40.2",
"transformers_version": "4.44.2",
"use_cache": true,
"vocab_size": 50257,
"weight_tying": true
81 changes: 80 additions & 1 deletion tests/hf_olmo/modeling_olmo_test.py
@@ -21,7 +21,86 @@ def test_olmo_model(model_path: str):
output = model(input_tensor)
hf_output = hf_model(input_tensor)

torch.testing.assert_allclose(output.logits, hf_output.logits)
torch.testing.assert_close(hf_output.logits, output.logits)


@pytest.mark.gpu
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA devices")
def test_flash_attention_2(model_path: str):
from transformers import AutoModelForCausalLM, AutoTokenizer

import hf_olmo # noqa: F401

hf_model = AutoModelForCausalLM.from_pretrained(model_path)
hf_model_flash_attn = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2")

tokenizer = AutoTokenizer.from_pretrained(model_path)
encoded_input = tokenizer.encode("My name is OLMo!")
input_tensor = torch.tensor(encoded_input).unsqueeze(0)

hf_output = hf_model(input_tensor)
hf_output_flash_attn = hf_model_flash_attn(input_tensor)

torch.testing.assert_close(hf_output_flash_attn.logits, hf_output.logits)


def test_sdpa(model_path: str):
from transformers import AutoModelForCausalLM, AutoTokenizer

import hf_olmo # noqa: F401

hf_model = AutoModelForCausalLM.from_pretrained(model_path)
hf_model_sdpa = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="sdpa")

tokenizer = AutoTokenizer.from_pretrained(model_path)
encoded_input = tokenizer.encode("My name is OLMo!")
input_tensor = torch.tensor(encoded_input).unsqueeze(0)

hf_output = hf_model(input_tensor)
hf_output_sdpa = hf_model_sdpa(input_tensor)

torch.testing.assert_close(hf_output_sdpa.logits, hf_output.logits)


def test_gradient_checkpointing(model_path: str):
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

import hf_olmo # noqa: F401

hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_path)

tokenizer = AutoTokenizer.from_pretrained(model_path)
encoded_input = tokenizer.encode("My name is OLMo!")
input_tensor = torch.tensor(encoded_input).unsqueeze(0)

hf_output_no_checkpointing = hf_model(input_tensor)

hf_model.gradient_checkpointing_enable()

hf_output_checkpointing = hf_model(input_tensor)

torch.testing.assert_close(hf_output_checkpointing.logits, hf_output_no_checkpointing.logits)


def test_gradient_checkpointing_disable(model_path: str):
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

import hf_olmo # noqa: F401

hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_path)

tokenizer = AutoTokenizer.from_pretrained(model_path)
encoded_input = tokenizer.encode("My name is OLMo!")
input_tensor = torch.tensor(encoded_input).unsqueeze(0)

hf_output = hf_model(input_tensor)

hf_model.gradient_checkpointing_enable()
hf_model.gradient_checkpointing_disable()

hf_output_after_disable = hf_model(input_tensor)

torch.testing.assert_close(hf_output_after_disable.logits, hf_output.logits)


def test_save_pretrained(model_path: str):
