Add Gemma2 support. (#562)
radi-cho authored Aug 4, 2024
1 parent 202b967 commit b9e9b73
Showing 8 changed files with 320 additions and 2 deletions.
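
For context, the new model class plugs into the usual AutoAWQ quantization flow. A minimal sketch, assuming a Gemma2 checkpoint is available locally or on the Hub (the model path, output path, and quantization settings below are illustrative, not part of this commit):

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "google/gemma-2-9b-it"   # illustrative checkpoint
quant_path = "gemma-2-9b-it-awq"      # illustrative output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the FP16 model and its tokenizer, run AWQ quantization, then save.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```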
1 change: 1 addition & 0 deletions awq/models/__init__.py
@@ -15,6 +15,7 @@
from .mixtral import MixtralAWQForCausalLM
from .qwen2 import Qwen2AWQForCausalLM
from .gemma import GemmaAWQForCausalLM
from .gemma2 import Gemma2AWQForCausalLM
from .stablelm import StableLmAWQForCausalLM
from .starcoder2 import Starcoder2AWQForCausalLM
from .llava_next import LlavaNextAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
@@ -25,6 +25,7 @@
"llava": LlavaAWQForCausalLM,
"qwen2": Qwen2AWQForCausalLM,
"gemma": GemmaAWQForCausalLM,
"gemma2": Gemma2AWQForCausalLM,
"stablelm": StableLmAWQForCausalLM,
"starcoder2": Starcoder2AWQForCausalLM,
"llava_next": LlavaNextAWQForCausalLM,
1 change: 1 addition & 0 deletions awq/models/base.py
@@ -76,6 +76,7 @@
"llava": "AutoModelForVision2Seq",
"qwen2": "AutoModelForCausalLM",
"gemma": "AutoModelForCausalLM",
"gemma2": "AutoModelForCausalLM",
"stablelm": "AutoModelForCausalLM",
"starcoder2": "AutoModelForCausalLM",
"llava_next": "AutoModelForVision2Seq",
159 changes: 159 additions & 0 deletions awq/models/gemma2.py
@@ -0,0 +1,159 @@
import copy
import tqdm
import torch
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import Gemma2LikeBlock
from awq.modules.fused.model import Gemma2LikeModel
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2DecoderLayer as OldGemmaDecoderLayer,
Gemma2ForCausalLM as OldGemmaForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class Gemma2AWQForCausalLM(BaseAWQForCausalLM):
layer_type = "Gemma2DecoderLayer"
max_new_tokens_key = "max_position_embeddings"

@staticmethod
    def fuse_layers(model: OldGemmaForCausalLM):
fuser = GemmaFuser(model)
fuser.fuse_transformer()

@staticmethod
def get_model_layers(model: OldGemmaForCausalLM):
return model.model.layers

@staticmethod
def get_act_for_scaling(module: OldGemmaDecoderLayer):
return dict(is_scalable=False)

@staticmethod
def move_embed(model: OldGemmaForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
def get_layers_for_scaling(module: OldGemmaDecoderLayer, input_feat, module_kwargs):
layers = []

# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)

# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)

layers.append(
dict(
prev_op=module.pre_feedforward_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)

layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)

return layers

class GemmaFuser:
def __init__(self, model: OldGemmaForCausalLM):
self.model = model

self.Gemma_blocks: List[Tuple[str, OldGemmaDecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "Gemma2DecoderLayer".lower() in module.__class__.__name__.lower() #Gemma2DecoderLayer
]

def fuse_transformer(self):
blocks = []

module: OldGemmaDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
with torch.no_grad():
                # Gemma2's RMSNorm differs from Llama's in that it multiplies the
                # output by (1 + weight) instead of just weight, so fold the +1 in.
module.input_layernorm.weight += 1
module.post_attention_layernorm.weight += 1
module.pre_feedforward_layernorm.weight += 1
module.post_feedforward_layernorm.weight += 1

norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.eps
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.eps,
)
norm_3 = FasterTransformerRMSNorm(
module.pre_feedforward_layernorm.weight,
module.pre_feedforward_layernorm.eps
)
norm_4 = FasterTransformerRMSNorm(
module.post_feedforward_layernorm.weight,
module.post_feedforward_layernorm.eps,
)
blocks.append(
Gemma2LikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
norm_3=norm_3,
norm_4=norm_4,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
head_dim=self.model.config.head_dim,
attn_logit_softcapping=self.model.config.attn_logit_softcapping,
)
)

self.model.model = Gemma2LikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
self.model.config.hidden_size,
)

setattr(self.model.model, "blocks", self.model.model.blocks)
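
The `weight += 1` adjustment in the fuser above exists because Gemma-style RMSNorm scales by `(1 + weight)` rather than `weight`. A small self-contained check of that equivalence (the helper functions here are illustrative, not part of the codebase):

```python
import torch

def rms_norm(x, weight, eps=1e-6):
    # Llama-style RMSNorm: normalize, then scale by `weight`.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def gemma_rms_norm(x, weight, eps=1e-6):
    # Gemma-style RMSNorm: normalize, then scale by `(1 + weight)`.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * (1.0 + weight)

x = torch.randn(2, 8, 16)
w = torch.randn(16) * 0.02
# Folding the +1 into the weights lets a plain RMSNorm reproduce the Gemma norm.
assert torch.allclose(rms_norm(x, w + 1.0), gemma_rms_norm(x, w), atol=1e-6)
```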
9 changes: 9 additions & 0 deletions awq/modules/fused/attn.py
@@ -118,6 +118,7 @@ def __init__(
rope_theta=10000,
partial_rotary_factor=1.0,
head_dim=None,
attn_logit_softcapping=None,
**kwargs
):
super().__init__()
@@ -173,6 +174,8 @@ def __init__(

if kwargs.get("is_neox") is not None:
self.is_neox = kwargs["is_neox"]

self.attn_logit_softcapping = attn_logit_softcapping

def forward(
self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
@@ -265,6 +268,12 @@ def forward(
values = values.transpose(1, 2)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

# Used in Gemma2
if self.attn_logit_softcapping is not None:
scores = scores / self.attn_logit_softcapping
scores = torch.tanh(scores)
scores = scores * self.attn_logit_softcapping

if self.use_alibi:
scores = self.alibi.forward(scores, seqlen)

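
The three-step softcapping block above is equivalent to `cap * tanh(scores / cap)`, which bounds the attention logits to the open interval `(-cap, +cap)`. A quick standalone illustration (the function name and tensor shapes are illustrative):

```python
import torch

def softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Squash the logits smoothly so their magnitude never exceeds `cap`.
    return cap * torch.tanh(scores / cap)

scores = torch.randn(1, 8, 4, 4) * 100.0
capped = softcap(scores, cap=50.0)  # Gemma2's attn_logit_softcapping defaults to 50.0
assert capped.abs().max() < 50.0
```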
81 changes: 81 additions & 0 deletions awq/modules/fused/block.py
@@ -132,6 +132,87 @@ def forward(
return out, None, past_key_value


class Gemma2LikeBlock(nn.Module):
def __init__(
self,
hidden_size,
n_heads,
n_kv_heads,
qkv_layer,
o_proj,
mlp,
norm_1,
norm_2,
norm_3,
norm_4,
dev,
max_seq_len,
rope_theta=10000,
partial_rotary_factor=1.0,
use_alibi=False,
head_dim=None,
attn_logit_softcapping=None,
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = hidden_size // n_heads

if head_dim:
self.head_dim = head_dim

self.hidden_size = hidden_size
self.norm_1 = norm_1.to(dev)
self.attn = QuantAttentionFused(
self.hidden_size,
self.n_heads,
self.n_kv_heads,
qkv_layer,
o_proj,
dev=dev,
max_seq_len=max_seq_len,
use_alibi=use_alibi,
rope_theta=rope_theta,
partial_rotary_factor=partial_rotary_factor,
head_dim=head_dim,
attn_logit_softcapping=attn_logit_softcapping,
).to(dev)

self.norm_2 = norm_2.to(dev)
self.norm_3 = norm_3.to(dev)
self.mlp = mlp.to(dev)
self.norm_4 = norm_4.to(dev)
self.device = dev

def forward(
self,
hidden_states,
past_key_value,
attn_bias=None,
attention_mask=None,
is_causal=None,
):
residual = hidden_states
hidden_states = self.norm_1(hidden_states)

hidden_states, _, past_key_value = self.attn.forward(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
)

hidden_states = self.norm_2(hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.norm_3(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.norm_4(hidden_states)
out = residual + hidden_states

return out, None, past_key_value


class CohereBlock(nn.Module):
def __init__(
self,
65 changes: 65 additions & 0 deletions awq/modules/fused/model.py
@@ -13,6 +13,7 @@
MixtralBlock,
Phi3Block,
CohereBlock,
Gemma2LikeBlock,
)


@@ -373,3 +374,67 @@ def forward(
hidden_states=(),
attentions=(),
)


class Gemma2LikeModel(nn.Module):
def __init__(self, vocab_size, blocks, embedding, norm, hidden_size):
super().__init__()
self.vocab_size = vocab_size
self.embedding = embedding
self.blocks: List[Gemma2LikeBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0
self.hidden_size = hidden_size

@property
def embed_tokens(self):
return self.embedding

@property
def layers(self):
return self.blocks

@torch.inference_mode()
def forward(
self,
input_ids: torch.Tensor,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape

fused_utils.prepare_cache(self.blocks, seqlen)

h = self.embedding(input_ids)

normalizer = torch.tensor(self.hidden_size**0.5, dtype=h.dtype)
h = h * normalizer

mask = fused_utils.prepare_attention_mask(
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h,
)

for layer in self.blocks:
h, mask = fused_utils.prepare_correct_devices(
layer,
h,
mask,
)
h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal)
h = self.norm(h)

return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=None,
hidden_states=(),
attentions=(),
)
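
With the fused block and model in place, loading a quantized Gemma2 checkpoint with fusion enabled exercises this path end to end. A hedged loading sketch, assuming a CUDA device and a previously quantized checkpoint (path and prompt are illustrative):

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "gemma-2-9b-it-awq"  # illustrative path to an AWQ-quantized checkpoint

# fuse_layers=True should invoke GemmaFuser.fuse_transformer() above,
# swapping in Gemma2LikeBlock / Gemma2LikeModel for inference.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

inputs = tokenizer("Explain AWQ quantization in one sentence.", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```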