diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 2ae3fd55..ceecf4f6 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -17,3 +17,4 @@
 from .gemma import GemmaAWQForCausalLM
 from .stablelm import StableLmAWQForCausalLM
 from .starcoder2 import Starcoder2AWQForCausalLM
+from .cohere import CohereAWQForCausalLM
diff --git a/awq/models/auto.py b/awq/models/auto.py
index 0a236979..a6ab87e5 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -26,6 +26,7 @@
     "gemma": GemmaAWQForCausalLM,
     "stablelm": StableLmAWQForCausalLM,
     "starcoder2": Starcoder2AWQForCausalLM,
+    "cohere": CohereAWQForCausalLM,
 }
diff --git a/awq/models/base.py b/awq/models/base.py
index ebd45ccc..ace0c112 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -70,6 +70,7 @@
     "gemma": "AutoModelForCausalLM",
     "stablelm": "AutoModelForCausalLM",
     "starcoder2": "AutoModelForCausalLM",
+    "cohere": "AutoModelForCausalLM",
 }
diff --git a/awq/models/cohere.py b/awq/models/cohere.py
new file mode 100644
index 00000000..9f669535
--- /dev/null
+++ b/awq/models/cohere.py
@@ -0,0 +1,128 @@
+import tqdm
+from typing import List, Tuple
+from .base import BaseAWQForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from awq.modules.fused.block import CohereBlock
+from awq.modules.fused.model import CohereModel
+from transformers.models.cohere.modeling_cohere import (
+    CohereDecoderLayer as OldCohereDecoderLayer,
+    CohereForCausalLM as OldCohereForCausalLM,
+)
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+class CohereAWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "CohereDecoderLayer"
+    max_seq_len_key = "max_position_embeddings"
+
+    @staticmethod
+    def fuse_layers(model: OldCohereForCausalLM):
+        fuser = CohereFuser(model)
+        fuser.fuse_transformer()
+
+    @staticmethod
+    def get_model_layers(model: OldCohereForCausalLM):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module: OldCohereDecoderLayer):
+        return dict(is_scalable=False)
+
+    @staticmethod
+    def move_embed(model: OldCohereForCausalLM, device: str):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(
+        module: OldCohereDecoderLayer, input_feat, module_kwargs
+    ):
+        layers = []
+
+        # input
+        layers.append(
+            dict(
+                prev_op=module.input_layernorm,
+                layers=[
+                    module.self_attn.q_proj,
+                    module.self_attn.k_proj,
+                    module.self_attn.v_proj,
+                    module.mlp.gate_proj,
+                    module.mlp.up_proj,
+                ],
+                inp=input_feat["self_attn.q_proj"],
+                module2inspect=module,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # attention out
+        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            layers.append(
+                dict(
+                    prev_op=module.self_attn.v_proj,
+                    layers=[module.self_attn.o_proj],
+                    inp=input_feat["self_attn.o_proj"],
+                )
+            )
+
+        # linear out
+        layers.append(
+            dict(
+                prev_op=module.mlp.up_proj,
+                layers=[module.mlp.down_proj],
+                inp=input_feat["mlp.down_proj"],
+            )
+        )
+
+        return layers
+
+class CohereFuser:
+    def __init__(self, model: OldCohereForCausalLM):
+        self.model = model
+
+        self.cohere_blocks: List[Tuple[str, OldCohereDecoderLayer]] = [
+            (name, module)
+            for name, module in self.model.named_modules()
+            if "CohereDecoderLayer".lower() in module.__class__.__name__.lower()
+        ]
+
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldCohereDecoderLayer
+        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
+            device = next(iter(module.state_dict().values())).device
+            qkv = fuse_qkv(
+                module,
+                module.self_attn.q_proj,
+                module.self_attn.k_proj,
+                module.self_attn.v_proj,
+            )
+            norm_1 = module.input_layernorm
+            # norm_2 = FasterTransformerRMSNorm(
+            #     module.post_attention_layernorm.weight,
+            #     module.post_attention_layernorm.variance_epsilon,
+            # )
+            blocks.append(
+                CohereBlock(
+                    hidden_size=self.model.config.hidden_size,
+                    n_heads=self.model.config.num_attention_heads,
+                    n_kv_heads=self.model.config.num_key_value_heads,
+                    qkv_layer=qkv,
+                    o_proj=module.self_attn.o_proj,
+                    mlp=module.mlp,
+                    norm_1=norm_1,
+                    # norm_2=norm_2,
+                    dev=device,
+                    max_seq_len=self.model.config.max_seq_len,
+                    rope_theta=self.model.config.rope_theta,
+                )
+            )
+
+        self.model.model = CohereModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.model.embed_tokens,
+            self.model.model.norm,
+        )
+        setattr(self.model.model, "blocks", self.model.model.blocks)
diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index e334dd7f..680b1198 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -171,6 +171,9 @@ def __init__(
             self.rope = RoPE(self.rotary_dim, max_seq_len, dev, rope_theta)
             self.is_neox = True
 
+        if kwargs.get("is_neox") is not None:
+            self.is_neox = kwargs["is_neox"]
+
     def forward(
         self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
     ):
diff --git a/awq/modules/fused/block.py b/awq/modules/fused/block.py
index e1971e37..0726a06a 100644
--- a/awq/modules/fused/block.py
+++ b/awq/modules/fused/block.py
@@ -132,6 +132,73 @@ def forward(
         return out, None, past_key_value
 
 
+class CohereBlock(nn.Module):
+    def __init__(
+        self,
+        hidden_size,
+        n_heads,
+        n_kv_heads,
+        qkv_layer,
+        o_proj,
+        mlp,
+        norm_1,
+        # norm_2,
+        dev,
+        max_seq_len,
+        rope_theta=10000,
+        partial_rotary_factor=1.0,
+        use_alibi=False,
+        head_dim=None,
+    ):
+        super().__init__()
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.head_dim = hidden_size // n_heads
+
+        # To support gemma-7b, its head_dim is separate
+        if head_dim:
+            self.head_dim = head_dim
+        self.hidden_size = hidden_size
+        self.norm_1 = norm_1.to(dev)
+        self.attn = QuantAttentionFused(
+            self.hidden_size,
+            self.n_heads,
+            self.n_kv_heads,
+            qkv_layer,
+            o_proj,
+            dev=dev,
+            max_seq_len=max_seq_len,
+            use_alibi=use_alibi,
+            rope_theta=rope_theta,
+            partial_rotary_factor=partial_rotary_factor,
+            head_dim=head_dim,
+            is_neox=False,
+        ).to(dev)
+        # self.norm_2 = norm_2.to(dev)
+        self.mlp = mlp.to(dev)
+        self.device = dev
+
+    def forward(
+        self,
+        hidden_states,
+        past_key_value,
+        attn_bias=None,
+        attention_mask=None,
+        is_causal=None,
+    ):
+        norm_out = self.norm_1(hidden_states)
+        attn_output, _, past_key_value = self.attn.forward(
+            hidden_states=norm_out,
+            past_key_value=past_key_value,
+            attention_mask=attention_mask,
+        )
+
+        h = hidden_states.to(attn_output.device) + attn_output
+        out = h + self.mlp.forward(norm_out)
+
+        return out, None, past_key_value
+
+
 class MPTBlock(nn.Module):
     def __init__(
diff --git a/awq/modules/fused/model.py b/awq/modules/fused/model.py
index 8733722b..16264ed1 100644
--- a/awq/modules/fused/model.py
+++ b/awq/modules/fused/model.py
@@ -11,6 +11,7 @@
     FalconDecoderLayer,
     LlamaLikeBlock,
     MixtralBlock,
+    CohereBlock,
 )
@@ -83,11 +84,11 @@ def __init__(self, vocab_size, blocks, embedding, norm):
         self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
         self.norm = norm
         self.last_forward_num_tokens = 0
-    
+
     @property
     def embed_tokens(self):
         return self.embedding
-    
+
     @property
     def layers(self):
         return self.blocks
@@ -124,9 +125,67 @@ def forward(
                 h,
                 mask,
             )
-            h, _, _ = layer(
-                h, None, attention_mask=mask, is_causal=is_causal
+            h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal)
+        h = self.norm(h)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=h,
+            past_key_values=None,
+            hidden_states=(),
+            attentions=(),
+        )
+
+
+class CohereModel(nn.Module):
+    def __init__(self, vocab_size, blocks, embedding, norm):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.embedding = embedding
+        self.blocks: List[CohereBlock] = nn.ModuleList(blocks)
+        self.norm = norm
+        self.last_forward_num_tokens = 0
+
+    @property
+    def embed_tokens(self):
+        return self.embedding
+
+    @property
+    def layers(self):
+        return self.blocks
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attn_bias=None,
+        attention_mask=None,
+        is_causal=None,
+        *args,
+        **kwargs,
+    ):
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
+            input_ids, self.last_forward_num_tokens
+        )
+        _bsz, seqlen = input_ids.shape
+
+        fused_utils.prepare_cache(self.blocks, seqlen)
+
+        h = self.embedding(input_ids)
+
+        mask = fused_utils.prepare_attention_mask(
+            seqlen=seqlen,
+            start_pos=self.blocks[0].attn.start_pos,
+            device=input_ids.device,
+            type_as=h,
+        )
+
+        for layer in self.blocks:
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
             )
+            h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal)
         h = self.norm(h)
 
         return BaseModelOutputWithPast(
diff --git a/awq/quantize/scale.py b/awq/quantize/scale.py
index 47899cc5..c072474a 100644
--- a/awq/quantize/scale.py
+++ b/awq/quantize/scale.py
@@ -7,9 +7,10 @@
 from transformers.models.bloom.modeling_bloom import BloomGelu
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
 from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
+from transformers.models.cohere.modeling_cohere import CohereLayerNorm
 from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
 
-allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm]
+allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm, CohereLayerNorm]
 allowed_act_fns = [
     nn.GELU,
     BloomGelu,
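For reviewers, a minimal sketch of how this change would be exercised end to end, following the standard AutoAWQ quantize/reload flow documented in the project README. The checkpoint id, output path, and quantization settings below are illustrative placeholders and are not part of this diff.

```python
# Example (not part of the diff): quantize a Cohere / Command-R checkpoint via the
# new CohereAWQForCausalLM mapping, then reload it with the fused Cohere path.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "CohereForAI/c4ai-command-r-v01"  # placeholder checkpoint id
quant_path = "command-r-awq"                   # placeholder output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# "cohere" now resolves through AWQ_CAUSAL_LM_MODEL_MAP / TRANSFORMERS_AUTO_MAPPING_DICT.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# AWQ calibration and quantization, using the groups from get_layers_for_scaling().
model.quantize(tokenizer, quant_config=quant_config)

# Persist the quantized weights and config.
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

# Reload with fuse_layers=True to exercise CohereFuser, CohereBlock, and the
# non-neox RoPE path (is_neox=False) added in attn.py.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
```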