From 94e73f0b2abb1d5303d72231540e922e0484383d Mon Sep 17 00:00:00 2001
From: TechxGenus <jianghao0728@mail.ustc.edu.cn>
Date: Mon, 11 Mar 2024 22:15:10 +0800
Subject: [PATCH] Add Gemma Support (#393)
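
Adds Gemma to the model registry and implements AWQ scaling and layer
fusion for it. The pieces that differ from the existing Llama path:

- GemmaRMSNorm multiplies its output by (1 + weight) rather than by
  weight, so scale_ln_fcs divides (1 + weight) by the scales, and the
  fuser adds 1 to the norm weights before wrapping them in
  FasterTransformerRMSNorm.
- Gemma multiplies token embeddings by sqrt(hidden_size) in its forward
  pass; the fused LlamaLikeModel does not, so that factor is folded into
  embed_tokens.weight during fusion.
- gemma-7b uses a head_dim that is not hidden_size // num_attention_heads,
  so RoPE, QuantAttentionFused and LlamaLikeBlock now accept an explicit
  head_dim.

Minimal usage sketch (follows the existing AutoAWQ quantize/load API; the
checkpoint path and quant_config values below are only illustrative):

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "google/gemma-7b-it"  # illustrative checkpoint
    quant_path = model_path + "-awq"
    quant_config = {
        "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"
    }

    # Quantize the FP16 model with AWQ and save the result.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)

    # Reload with the fused Gemma modules for inference.
    model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)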

---
 awq/models/__init__.py     |   1 +
 awq/models/auto.py         |   1 +
 awq/models/base.py         |   1 +
 awq/models/gemma.py        | 149 +++++++++++++++++++++++++++++++++++++
 awq/modules/fused/attn.py  |  13 +++-
 awq/modules/fused/block.py |   8 ++
 awq/modules/fused/model.py |   4 +-
 awq/quantize/scale.py      |  13 +++-
 8 files changed, 182 insertions(+), 8 deletions(-)
 create mode 100644 awq/models/gemma.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 14886a24..75542fe4 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -14,3 +14,4 @@
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
+from .gemma import GemmaAWQForCausalLM
diff --git a/awq/models/auto.py b/awq/models/auto.py
index c992061f..1ac6342a 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -23,6 +23,7 @@
     "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
+    "gemma": GemmaAWQForCausalLM,
 }
 
 
diff --git a/awq/models/base.py b/awq/models/base.py
index 8ef243ab..e5691ae0 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -67,6 +67,7 @@
     "baichuan": "AutoModelForCausalLM",
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
+    "gemma": "AutoModelForCausalLM",
 }
 
 
diff --git a/awq/models/gemma.py b/awq/models/gemma.py
new file mode 100644
index 00000000..b3ed65db
--- /dev/null
+++ b/awq/models/gemma.py
@@ -0,0 +1,149 @@
+import tqdm
+import torch
+from typing import List, Tuple
+from .base import BaseAWQForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from awq.modules.fused.block import LlamaLikeBlock
+from awq.modules.fused.model import LlamaLikeModel
+from transformers.models.gemma.modeling_gemma import (
+    GemmaDecoderLayer as OldGemmaDecoderLayer,
+    GemmaForCausalLM as OldGemmaForCausalLM,
+)
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+
+class GemmaAWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "GemmaDecoderLayer"
+    max_new_tokens_key = "max_position_embeddings"
+
+    @staticmethod
+    def fuse_layers(model: OldGemmaForCausalLM):
+        fuser = GemmaFuser(model)
+        fuser.fuse_transformer()
+
+    @staticmethod
+    def get_model_layers(model: OldGemmaForCausalLM):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module: OldGemmaDecoderLayer):
+        return dict(is_scalable=False)
+
+    @staticmethod
+    def move_embed(model: OldGemmaForCausalLM, device: str):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module: OldGemmaDecoderLayer, input_feat, module_kwargs):
+        layers = []
+
+        # attention input
+        layers.append(
+            dict(
+                prev_op=module.input_layernorm,
+                layers=[
+                    module.self_attn.q_proj,
+                    module.self_attn.k_proj,
+                    module.self_attn.v_proj,
+                ],
+                inp=input_feat["self_attn.q_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # attention out
+        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            layers.append(
+                dict(
+                    prev_op=module.self_attn.v_proj,
+                    layers=[module.self_attn.o_proj],
+                    inp=input_feat["self_attn.o_proj"],
+                )
+            )
+
+        # linear 1
+        layers.append(
+            dict(
+                prev_op=module.post_attention_layernorm,
+                layers=[module.mlp.gate_proj, module.mlp.up_proj],
+                inp=input_feat["mlp.gate_proj"],
+                module2inspect=module.mlp,
+            )
+        )
+
+        # linear 2
+        layers.append(
+            dict(
+                prev_op=module.mlp.up_proj,
+                layers=[module.mlp.down_proj],
+                inp=input_feat["mlp.down_proj"],
+            )
+        )
+
+        return layers
+
+
+class GemmaFuser:
+    def __init__(self, model: OldGemmaForCausalLM):
+        self.model = model
+
+        self.gemma_blocks: List[Tuple[str, OldGemmaDecoderLayer]] = [
+            (name, module)
+            for name, module in self.model.named_modules()
+            if "GemmaDecoderLayer".lower() in module.__class__.__name__.lower()
+        ]
+
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldGemmaDecoderLayer
+        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
+            device = next(iter(module.state_dict().values())).device
+            qkv = fuse_qkv(
+                module,
+                module.self_attn.q_proj,
+                module.self_attn.k_proj,
+                module.self_attn.v_proj,
+            )
+            with torch.no_grad():
+                # GemmaRMSNorm differs from Llama's RMSNorm: it multiplies the
+                # output by (1 + weight) instead of by weight alone.
+                module.input_layernorm.weight += 1
+                module.post_attention_layernorm.weight += 1
+            norm_1 = FasterTransformerRMSNorm(
+                module.input_layernorm.weight, module.input_layernorm.eps
+            )
+            norm_2 = FasterTransformerRMSNorm(
+                module.post_attention_layernorm.weight,
+                module.post_attention_layernorm.eps,
+            )
+            blocks.append(
+                LlamaLikeBlock(
+                    hidden_size=self.model.config.hidden_size,
+                    n_heads=self.model.config.num_attention_heads,
+                    n_kv_heads=self.model.config.num_key_value_heads,
+                    qkv_layer=qkv,
+                    o_proj=module.self_attn.o_proj,
+                    mlp=module.mlp,
+                    norm_1=norm_1,
+                    norm_2=norm_2,
+                    dev=device,
+                    max_seq_len=self.model.config.max_seq_len,
+                    rope_theta=self.model.config.rope_theta,
+                    head_dim=self.model.config.head_dim,
+                )
+            )
+
+        with torch.no_grad():
+            # Gemma scales embeddings by sqrt(hidden_size); fold that factor into the weights
+            self.model.model.embed_tokens.weight *= self.model.config.hidden_size**0.5
+
+        self.model.model = LlamaLikeModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.model.embed_tokens,
+            self.model.model.norm,
+        )
+        setattr(self.model.model, "blocks", self.model.model.blocks)
diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index f90fd502..f1732ea5 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -25,12 +25,12 @@
 
 
 class RoPE(nn.Module):
-    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
+    def __init__(self, head_dim, max_seq_len, device, rope_theta):
         super(RoPE, self).__init__()
 
         self.freqs_cis = nn.Parameter(
             self.precompute_freqs_cis(
-                hidden_size // n_heads, max_seq_len * 2, rope_theta
+                head_dim, max_seq_len * 2, rope_theta
             ).to(device),
             requires_grad=False,
         )
@@ -118,6 +118,7 @@ def __init__(
         use_alibi=False,
         attention_shapes=None,
         rope_theta=10000,
+        head_dim=None,
         **kwargs
     ):
         super().__init__()
@@ -125,7 +126,11 @@ def __init__(
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
         self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
-        self.head_dim = self.hidden_size // n_heads
+        self.head_dim = head_dim
+
+        if head_dim is None:
+            self.head_dim = hidden_size // n_heads
+
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0
@@ -162,7 +167,7 @@ def __init__(
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
+            self.rope = RoPE(self.head_dim, max_seq_len, dev, rope_theta)
             self.rotary_dim = self.head_dim
             self.is_neox = True
 
diff --git a/awq/modules/fused/block.py b/awq/modules/fused/block.py
index 0ffc4b93..23cd954d 100644
--- a/awq/modules/fused/block.py
+++ b/awq/modules/fused/block.py
@@ -80,10 +80,17 @@ def __init__(
         max_seq_len,
         rope_theta=10000,
         use_alibi=False,
+        head_dim=None,
     ):
         super().__init__()
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
+        self.head_dim = hidden_size // n_heads
+
+        # Gemma-7B's head_dim is not hidden_size // n_heads, so allow an explicit override
+        if head_dim:
+            self.head_dim = head_dim
+
         self.hidden_size = hidden_size
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
@@ -96,6 +103,7 @@ def __init__(
             max_seq_len=max_seq_len,
             use_alibi=use_alibi,
             rope_theta=rope_theta,
+            head_dim=head_dim,
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
diff --git a/awq/modules/fused/model.py b/awq/modules/fused/model.py
index c02233f6..c1ba2c1e 100644
--- a/awq/modules/fused/model.py
+++ b/awq/modules/fused/model.py
@@ -116,14 +116,14 @@ def forward(
                 h,
                 mask,
             )
-            h, _, past_key_value = layer(
+            h, _, _ = layer(
                 h, None, attention_mask=mask, is_causal=is_causal
             )
         h = self.norm(h)
 
         return BaseModelOutputWithPast(
             last_hidden_state=h,
-            past_key_values=past_key_value,
+            past_key_values=None,
             hidden_states=(),
             attentions=(),
         )
diff --git a/awq/quantize/scale.py b/awq/quantize/scale.py
index 0ee6ea05..47899cc5 100644
--- a/awq/quantize/scale.py
+++ b/awq/quantize/scale.py
@@ -6,9 +6,10 @@
 from awq.utils.module import get_op_by_name, set_op_by_name
 from transformers.models.bloom.modeling_bloom import BloomGelu
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
 from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
 
-allowed_norms = [nn.LayerNorm, LlamaRMSNorm]
+allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm]
 allowed_act_fns = [
     nn.GELU,
     BloomGelu,
@@ -88,7 +89,15 @@ def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
 
     scales = scales.to(ln.weight.device)
 
-    ln.weight.div_(scales)
+    # GemmaRMSNorm differs from Llama's RMSNorm: it multiplies the output
+    # by (1 + weight) instead of by weight alone.
+    if isinstance(ln, GemmaRMSNorm):
+        ln.weight += 1
+        ln.weight.div_(scales)
+        ln.weight -= 1
+    else:
+        ln.weight.div_(scales)
+
     if hasattr(ln, "bias") and ln.bias is not None:
         ln.bias.div_(scales)