Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cohere Support #457

Merged
merged 5 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions awq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
from .gemma import GemmaAWQForCausalLM
from .stablelm import StableLmAWQForCausalLM
from .starcoder2 import Starcoder2AWQForCausalLM
from .cohere import CohereAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"gemma": GemmaAWQForCausalLM,
"stablelm": StableLmAWQForCausalLM,
"starcoder2": Starcoder2AWQForCausalLM,
"cohere": CohereAWQForCausalLM,
}


Expand Down
1 change: 1 addition & 0 deletions awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"gemma": "AutoModelForCausalLM",
"stablelm": "AutoModelForCausalLM",
"starcoder2": "AutoModelForCausalLM",
"cohere": "AutoModelForCausalLM",
}


Expand Down
128 changes: 128 additions & 0 deletions awq/models/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import CohereBlock
from awq.modules.fused.model import CohereModel
from transformers.models.cohere.modeling_cohere import (
CohereDecoderLayer as OldCohereDecoderLayer,
CohereForCausalLM as OldCohereForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm

class CohereAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "CohereDecoderLayer"
max_seq_len_key = "max_position_embeddings"

@staticmethod
def fuse_layers(model: OldCohereForCausalLM):
fuser = CohereFuser(model)
fuser.fuse_transformer()

@staticmethod
def get_model_layers(model: OldCohereForCausalLM):
return model.model.layers

@staticmethod
def get_act_for_scaling(module: OldCohereDecoderLayer):
return dict(is_scalable=False)

@staticmethod
def move_embed(model: OldCohereForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
def get_layers_for_scaling(
module: OldCohereDecoderLayer, input_feat, module_kwargs
):
layers = []

# input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
module.mlp.gate_proj,
module.mlp.up_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module,
kwargs=module_kwargs,
)
)

# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)

# linear out
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)

return layers

class CohereFuser:
def __init__(self, model: OldCohereForCausalLM):
self.model = model

self.cohere_blocks: List[Tuple[str, OldCohereDecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "CohereDecoderLayer".lower() in module.__class__.__name__.lower()
]

def fuse_transformer(self):
blocks = []

module: OldCohereDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
norm_1 = module.input_layernorm
# norm_2 = FasterTransformerRMSNorm(
# module.post_attention_layernorm.weight,
# module.post_attention_layernorm.variance_epsilon,
# )
blocks.append(
CohereBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
# norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)

self.model.model = CohereModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
3 changes: 3 additions & 0 deletions awq/modules/fused/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ def __init__(
self.rope = RoPE(self.rotary_dim, max_seq_len, dev, rope_theta)
self.is_neox = True

if kwargs.get("is_neox") is not None:
self.is_neox = kwargs["is_neox"]

def forward(
self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
):
Expand Down
67 changes: 67 additions & 0 deletions awq/modules/fused/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,73 @@ def forward(
return out, None, past_key_value


class CohereBlock(nn.Module):
def __init__(
self,
hidden_size,
n_heads,
n_kv_heads,
qkv_layer,
o_proj,
mlp,
norm_1,
# norm_2,
dev,
max_seq_len,
rope_theta=10000,
partial_rotary_factor=1.0,
use_alibi=False,
head_dim=None,
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = hidden_size // n_heads

# To support gemma-7b, its head_dim is separate
if head_dim:
self.head_dim = head_dim
self.hidden_size = hidden_size
self.norm_1 = norm_1.to(dev)
self.attn = QuantAttentionFused(
self.hidden_size,
self.n_heads,
self.n_kv_heads,
qkv_layer,
o_proj,
dev=dev,
max_seq_len=max_seq_len,
use_alibi=use_alibi,
rope_theta=rope_theta,
partial_rotary_factor=partial_rotary_factor,
head_dim=head_dim,
is_neox=False,
).to(dev)
# self.norm_2 = norm_2.to(dev)
self.mlp = mlp.to(dev)
self.device = dev

def forward(
self,
hidden_states,
past_key_value,
attn_bias=None,
attention_mask=None,
is_causal=None,
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask,
)

h = hidden_states.to(attn_output.device) + attn_output
out = h + self.mlp.forward(norm_out)

return out, None, past_key_value


class MPTBlock(nn.Module):
def __init__(
self,
Expand Down
67 changes: 63 additions & 4 deletions awq/modules/fused/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
FalconDecoderLayer,
LlamaLikeBlock,
MixtralBlock,
CohereBlock,
)


Expand Down Expand Up @@ -83,11 +84,11 @@ def __init__(self, vocab_size, blocks, embedding, norm):
self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0

@property
def embed_tokens(self):
return self.embedding

@property
def layers(self):
return self.blocks
Expand Down Expand Up @@ -124,9 +125,67 @@ def forward(
h,
mask,
)
h, _, _ = layer(
h, None, attention_mask=mask, is_causal=is_causal
h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal)
h = self.norm(h)

return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=None,
hidden_states=(),
attentions=(),
)


class CohereModel(nn.Module):
def __init__(self, vocab_size, blocks, embedding, norm):
super().__init__()
self.vocab_size = vocab_size
self.embedding = embedding
self.blocks: List[CohereBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0

@property
def embed_tokens(self):
return self.embedding

@property
def layers(self):
return self.blocks

@torch.inference_mode()
def forward(
self,
input_ids: torch.Tensor,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape

fused_utils.prepare_cache(self.blocks, seqlen)

h = self.embedding(input_ids)

mask = fused_utils.prepare_attention_mask(
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h,
)

for layer in self.blocks:
h, mask = fused_utils.prepare_correct_devices(
layer,
h,
mask,
)
h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal)
h = self.norm(h)

return BaseModelOutputWithPast(
Expand Down
3 changes: 2 additions & 1 deletion awq/quantize/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
from transformers.models.cohere.modeling_cohere import CohereLayerNorm
from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation

allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm]
allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm, CohereLayerNorm]
allowed_act_fns = [
nn.GELU,
BloomGelu,
Expand Down