From 94e73f0b2abb1d5303d72231540e922e0484383d Mon Sep 17 00:00:00 2001
From: TechxGenus <jianghao0728@mail.ustc.edu.cn>
Date: Mon, 11 Mar 2024 22:15:10 +0800
Subject: [PATCH] Add Gemma Support (#393)

---
 awq/models/__init__.py     |   1 +
 awq/models/auto.py         |   1 +
 awq/models/base.py         |   1 +
 awq/models/gemma.py        | 149 +++++++++++++++++++++++++++++++++++++
 awq/modules/fused/attn.py  |  13 +++-
 awq/modules/fused/block.py |   8 ++
 awq/modules/fused/model.py |   4 +-
 awq/quantize/scale.py      |  13 +++-
 8 files changed, 182 insertions(+), 8 deletions(-)
 create mode 100644 awq/models/gemma.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 14886a24..75542fe4 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -14,3 +14,4 @@
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
+from .gemma import GemmaAWQForCausalLM
diff --git a/awq/models/auto.py b/awq/models/auto.py
index c992061f..1ac6342a 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -23,6 +23,7 @@
     "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
+    "gemma": GemmaAWQForCausalLM,
 }
diff --git a/awq/models/base.py b/awq/models/base.py
index 8ef243ab..e5691ae0 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -67,6 +67,7 @@
     "baichuan": "AutoModelForCausalLM",
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
+    "gemma": "AutoModelForCausalLM",
 }
diff --git a/awq/models/gemma.py b/awq/models/gemma.py
new file mode 100644
index 00000000..b3ed65db
--- /dev/null
+++ b/awq/models/gemma.py
@@ -0,0 +1,149 @@
+import tqdm
+import torch
+from typing import List, Tuple
+from .base import BaseAWQForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from awq.modules.fused.block import LlamaLikeBlock
+from awq.modules.fused.model import LlamaLikeModel
+from transformers.models.gemma.modeling_gemma import (
+    GemmaDecoderLayer as OldGemmaDecoderLayer,
+    GemmaForCausalLM as OldGemmaForCausalLM,
+)
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+
+class GemmaAWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "GemmaDecoderLayer"
+    max_new_tokens_key = "max_position_embeddings"
+
+    @staticmethod
+    def fuse_layers(model: OldGemmaDecoderLayer):
+        fuser = GemmaFuser(model)
+        fuser.fuse_transformer()
+
+    @staticmethod
+    def get_model_layers(model: OldGemmaForCausalLM):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module: OldGemmaDecoderLayer):
+        return dict(is_scalable=False)
+
+    @staticmethod
+    def move_embed(model: OldGemmaForCausalLM, device: str):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module: OldGemmaDecoderLayer, input_feat, module_kwargs):
+        layers = []
+
+        # attention input
+        layers.append(
+            dict(
+                prev_op=module.input_layernorm,
+                layers=[
+                    module.self_attn.q_proj,
+                    module.self_attn.k_proj,
+                    module.self_attn.v_proj,
+                ],
+                inp=input_feat["self_attn.q_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # attention out
+        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            layers.append(
+                dict(
+                    prev_op=module.self_attn.v_proj,
+                    layers=[module.self_attn.o_proj],
+                    inp=input_feat["self_attn.o_proj"],
+                )
+            )
+
+        # linear 1
+        layers.append(
+            dict(
+                prev_op=module.post_attention_layernorm,
+                layers=[module.mlp.gate_proj, module.mlp.up_proj],
+                inp=input_feat["mlp.gate_proj"],
+                module2inspect=module.mlp,
+            )
+        )
+
+        # linear 2
+        layers.append(
+            dict(
+                prev_op=module.mlp.up_proj,
+                layers=[module.mlp.down_proj],
+                inp=input_feat["mlp.down_proj"],
+            )
+        )
+
+        return layers
+
+
+class GemmaFuser:
+    def __init__(self, model: OldGemmaForCausalLM):
+        self.model = model
+
+        self.Gemma_blocks: List[Tuple[str, OldGemmaDecoderLayer]] = [
+            (name, module)
+            for name, module in self.model.named_modules()
+            if "GemmaDecoderLayer".lower() in module.__class__.__name__.lower()
+        ]
+
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldGemmaDecoderLayer
+        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
+            device = next(iter(module.state_dict().values())).device
+            qkv = fuse_qkv(
+                module,
+                module.self_attn.q_proj,
+                module.self_attn.k_proj,
+                module.self_attn.v_proj,
+            )
+            with torch.no_grad():
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                module.input_layernorm.weight += 1
+                module.post_attention_layernorm.weight += 1
+            norm_1 = FasterTransformerRMSNorm(
+                module.input_layernorm.weight, module.input_layernorm.eps
+            )
+            norm_2 = FasterTransformerRMSNorm(
+                module.post_attention_layernorm.weight,
+                module.post_attention_layernorm.eps,
+            )
+            blocks.append(
+                LlamaLikeBlock(
+                    hidden_size=self.model.config.hidden_size,
+                    n_heads=self.model.config.num_attention_heads,
+                    n_kv_heads=self.model.config.num_key_value_heads,
+                    qkv_layer=qkv,
+                    o_proj=module.self_attn.o_proj,
+                    mlp=module.mlp,
+                    norm_1=norm_1,
+                    norm_2=norm_2,
+                    dev=device,
+                    max_seq_len=self.model.config.max_seq_len,
+                    rope_theta=self.model.config.rope_theta,
+                    head_dim=self.model.config.head_dim,
+                )
+            )
+
+        with torch.no_grad():
+            # Normalize Gemma's embedding layer
+            self.model.model.embed_tokens.weight *= self.model.config.hidden_size**0.5
+
+        self.model.model = LlamaLikeModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.model.embed_tokens,
+            self.model.model.norm,
+        )
+        setattr(self.model.model, "blocks", self.model.model.blocks)
diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index f90fd502..f1732ea5 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -25,12 +25,12 @@


 class RoPE(nn.Module):
-    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
+    def __init__(self, head_dim, max_seq_len, device, rope_theta):
         super(RoPE, self).__init__()

         self.freqs_cis = nn.Parameter(
             self.precompute_freqs_cis(
-                hidden_size // n_heads, max_seq_len * 2, rope_theta
+                head_dim, max_seq_len * 2, rope_theta
             ).to(device),
             requires_grad=False,
         )
@@ -118,6 +118,7 @@ def __init__(
         use_alibi=False,
         attention_shapes=None,
         rope_theta=10000,
+        head_dim=None,
         **kwargs
     ):
         super().__init__()
@@ -125,7 +126,11 @@
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
         self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
-        self.head_dim = self.hidden_size // n_heads
+        self.head_dim = head_dim
+
+        if head_dim is None:
+            self.head_dim = hidden_size // n_heads
+
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0
@@ -162,7 +167,7 @@
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
+            self.rope = RoPE(self.head_dim, max_seq_len, dev, rope_theta)
             self.rotary_dim = self.head_dim
             self.is_neox = True
diff --git a/awq/modules/fused/block.py b/awq/modules/fused/block.py
index 0ffc4b93..23cd954d 100644
--- a/awq/modules/fused/block.py
+++ b/awq/modules/fused/block.py
@@ -80,10 +80,17 @@ def __init__(
         max_seq_len,
         rope_theta=10000,
         use_alibi=False,
+        head_dim=None,
     ):
         super().__init__()
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
+        self.head_dim = hidden_size // n_heads
+
+        # To support gemma-7b, its head_dim is separate
+        if head_dim:
+            self.head_dim = head_dim
+
         self.hidden_size = hidden_size
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
@@ -96,6 +103,7 @@
             max_seq_len=max_seq_len,
             use_alibi=use_alibi,
             rope_theta=rope_theta,
+            head_dim=head_dim,
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
diff --git a/awq/modules/fused/model.py b/awq/modules/fused/model.py
index c02233f6..c1ba2c1e 100644
--- a/awq/modules/fused/model.py
+++ b/awq/modules/fused/model.py
@@ -116,14 +116,14 @@ def forward(
                 h,
                 mask,
             )
-            h, _, past_key_value = layer(
+            h, _, _ = layer(
                 h, None, attention_mask=mask, is_causal=is_causal
             )
         h = self.norm(h)

         return BaseModelOutputWithPast(
             last_hidden_state=h,
-            past_key_values=past_key_value,
+            past_key_values=None,
             hidden_states=(),
             attentions=(),
         )
diff --git a/awq/quantize/scale.py b/awq/quantize/scale.py
index 0ee6ea05..47899cc5 100644
--- a/awq/quantize/scale.py
+++ b/awq/quantize/scale.py
@@ -6,9 +6,10 @@
 from awq.utils.module import get_op_by_name, set_op_by_name
 from transformers.models.bloom.modeling_bloom import BloomGelu
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
 from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation

-allowed_norms = [nn.LayerNorm, LlamaRMSNorm]
+allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm]
 allowed_act_fns = [
     nn.GELU,
     BloomGelu,
@@ -88,7 +89,15 @@ def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):

     scales = scales.to(ln.weight.device)

-    ln.weight.div_(scales)
+    # GemmaRMSNorm is different from Llama's in that it multiplies
+    # (1 + weight) to the output, instead of just weight.
+    if isinstance(ln, GemmaRMSNorm):
+        ln.weight += 1
+        ln.weight.div_(scales)
+        ln.weight -= 1
+    else:
+        ln.weight.div_(scales)
+
     if hasattr(ln, "bias") and ln.bias is not None:
         ln.bias.div_(scales)
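
Note on the scale.py change above: GemmaRMSNorm multiplies the normalized activations by (1 + weight) rather than weight, so scale_ln_fcs must divide the effective multiplier, not the stored parameter. A minimal standalone sketch of the arithmetic (illustrative only, not part of the patch):

import torch

# Stored GemmaRMSNorm weight and hypothetical per-channel AWQ scales.
w = torch.tensor([0.5, -0.2, 0.0])
scales = torch.tensor([2.0, 4.0, 1.0])

# What scale_ln_fcs does for GemmaRMSNorm: shift to the effective
# multiplier (1 + w), divide by the scales, then shift back so the
# stored parameter keeps Gemma's (1 + weight) convention.
w_scaled = (w + 1) / scales - 1

# The new effective multiplier equals the Llama-style division of (1 + w).
assert torch.allclose(1 + w_scaled, (1 + w) / scales)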
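
With this patch applied, Gemma checkpoints go through the usual AutoAWQ flow. A quick sketch following the standard AutoAWQ quantization example (the model path, output directory, and quant settings below are illustrative, not prescribed by the patch):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "google/gemma-7b-it"   # illustrative checkpoint
quant_path = "gemma-7b-it-awq"      # illustrative output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the FP16 model and tokenizer, run AWQ quantization, then save.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)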