Add GPTJ Support #1

Merged 4 commits on Aug 27, 2023
3 changes: 2 additions & 1 deletion awq/models/__init__.py
@@ -2,4 +2,5 @@
 from .llama import LlamaAWQForCausalLM
 from .opt import OptAWQForCausalLM
 from .falcon import FalconAWQForCausalLM
-from .bloom import BloomAWQForCausalLM
+from .bloom import BloomAWQForCausalLM
+from .gptj import GPTJAWQForCausalLM
3 changes: 2 additions & 1 deletion awq/models/auto.py
@@ -8,7 +8,8 @@
     "opt": OptAWQForCausalLM,
     "RefinedWeb": FalconAWQForCausalLM,
     "RefinedWebModel": FalconAWQForCausalLM,
-    "bloom": BloomAWQForCausalLM
+    "bloom": BloomAWQForCausalLM,
+    "gptj": GPTJAWQForCausalLM
 }
 
 def check_and_get_model_type(model_dir, trust_remote_code=True):
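For context, a minimal sketch of how the new "gptj" entry is expected to be consumed: the loader reads model_type from the checkpoint's Hugging Face config and looks it up in the map above. The map and helper names below are illustrative assumptions, not the repo's exact identifiers.

# Illustrative sketch (assumed names, not the repo's exact code): routing a
# GPT-J checkpoint to GPTJAWQForCausalLM via config.model_type.
from transformers import AutoConfig

AWQ_MODEL_MAP = {
    # the real map stores the classes themselves; a string keeps this sketch self-contained
    "gptj": "GPTJAWQForCausalLM",
}

def resolve_awq_class(model_dir, trust_remote_code=True):
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    model_type = config.model_type  # "gptj" for GPT-J checkpoints
    if model_type not in AWQ_MODEL_MAP:
        raise TypeError(f"{model_type} isn't supported yet")
    return AWQ_MODEL_MAP[model_type]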
6 changes: 3 additions & 3 deletions awq/models/base.py
@@ -113,8 +113,8 @@ def __init__(self, module):
         super().__init__()
         self.module = module
 
-    def forward(self, inp, **kwargs):
-        inps.append(inp)
+    def forward(self, hijacked_inputs, **kwargs):
+        inps.append(hijacked_inputs)
         layer_kwargs.update(kwargs)
         raise ValueError # early exit to break later inference
@@ -358,4 +358,4 @@ def _scale_activations(self, layer):
 
         # scale activation
         scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
-        set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
+        set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
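To make the renamed forward() easier to follow, here is a standalone sketch of the input-capture trick it belongs to: the first decoder block is temporarily wrapped so that one calibration forward pass records the block's inputs and keyword arguments, then aborts. Everything outside the two diffed lines (the wrapper class, the demo block) is assumed for illustration.

# Standalone sketch of the input-capture pattern the renamed forward() above
# belongs to; names outside the diffed lines are illustrative assumptions.
import torch
import torch.nn as nn

inps, layer_kwargs = [], {}

class Catcher(nn.Module):
    """Wraps the first decoder block, records its inputs, then aborts the forward pass."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, hijacked_inputs, **kwargs):
        inps.append(hijacked_inputs)   # hidden states fed to the first block
        layer_kwargs.update(kwargs)    # attention mask, position ids, ...
        raise ValueError               # early exit to break later inference

# Minimal demo with a stand-in block instead of a real GPT-J layer:
block = Catcher(nn.Linear(16, 16))
try:
    block(torch.randn(1, 4, 16), attention_mask=None)
except ValueError:
    pass
assert len(inps) == 1 and "attention_mask" in layer_kwargs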
53 changes: 53 additions & 0 deletions awq/models/gptj.py
@@ -0,0 +1,53 @@
from .base import BaseAWQForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock

class GPTJAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "GPTJBlock"
    max_new_tokens_key = "n_positions"

    @staticmethod
    def get_model_layers(model: GPTJForCausalLM):
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: GPTJBlock):
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
            scale_layer=module.mlp.act,
            scale_shape=module.mlp.fc_in.out_features
        )

    @staticmethod
    def move_embed(model: GPTJForCausalLM, device: str):
        model.transformer.wte = model.transformer.wte.to(device)

    @staticmethod
    def get_layers_for_scaling(module: GPTJBlock, input_feat, module_kwargs):
        layers = []

        # attention input + linear 1
        layers.append(dict(
            prev_op=module.ln_1,
            layers=[module.attn.q_proj,
                    module.attn.k_proj, module.attn.v_proj, module.mlp.fc_in],
            inp=input_feat['attn.q_proj'],
            module2inspect=module,
            kwargs=module_kwargs
        ))

        # attention out
        layers.append(dict(
            prev_op=module.attn.v_proj,
            layers=[module.attn.out_proj],
            inp=input_feat['attn.out_proj'],
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.act,
            layers=[module.mlp.fc_out],
            inp=input_feat['mlp.fc_out'],
        ))

        return layers
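Each dict returned by get_layers_for_scaling ties a preceding op to the linear layers that consume its output, which is what the AWQ search scales as a group. A hypothetical end-to-end use of the new mapping might look like the sketch below; the top-level API (AutoAWQForCausalLM, quantize, save_quantized) and the quant_config keys are assumptions based on the project's usual entry points, not part of this diff.

# Hypothetical usage sketch; entry-point and argument names are assumptions.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "EleutherAI/gpt-j-6b"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}  # assumed keys

# from_pretrained should route through the "gptj" entry added in auto.py
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(tokenizer, quant_config=quant_config)  # AWQ scale search + 4-bit packing
model.save_quantized("gpt-j-6b-awq")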
6 changes: 3 additions & 3 deletions awq/quantize/auto_scale.py
@@ -5,7 +5,7 @@
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
-
+from transformers.activations import NewGELUActivation
 from .qmodule import ScaledActivation
 from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
 
@@ -79,7 +79,7 @@ def scale_fc_fc(fc1, fc2, scales):
 
 @torch.no_grad()
 def scale_gelu_fc(gelu, fc, scales):
-    assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
+    assert any(isinstance(gelu,t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
     assert isinstance(fc, nn.Linear)
 
     fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
@@ -195,7 +195,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
             scale_fc_fc(prev_op, layers[0], scales)
         elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
             scale_ln_fcs(prev_op, layers, scales)
-        elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
+        elif any(isinstance(prev_op,t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
             new_module = ScaledActivation(prev_op, scales)
             set_op_by_name(module, prev_op_name, new_module)
             scale_gelu_fc(prev_op, layers[0], scales)
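GPT-J's MLP uses transformers' NewGELUActivation ("gelu_new") rather than nn.GELU, so without the new isinstance entries its activation would never be wrapped or scaled. The snippet below is a self-contained check of the invariance this relies on; the division by scales is an assumption mirroring the usual AWQ ScaledActivation formulation. Dividing the activation output by per-channel scales while multiplying the next Linear's weights by the same scales, as scale_gelu_fc does, leaves the layer output unchanged.

# Sketch: scaling equivalence for GPT-J's NewGELUActivation (assumed to mirror
# what ScaledActivation does: divide the activation output by the scales).
import torch
import torch.nn as nn
from transformers.activations import NewGELUActivation

torch.manual_seed(0)
act, fc = NewGELUActivation(), nn.Linear(8, 4, bias=False)
scales = torch.rand(8) + 0.5
x = torch.randn(2, 8)

ref = fc(act(x))                          # original computation

fc.weight.data.mul_(scales.view(1, -1))   # what scale_gelu_fc does to the next Linear
scaled_out = fc(act(x) / scales)          # what the wrapped activation would produce

assert torch.allclose(ref, scaled_out, atol=1e-5)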