Added support for Jais models #3183

Merged 51 commits on Mar 21, 2024. The diff below shows changes from 47 commits.

Commits (51)
5a3ddef  updated code for jais (Mar 4, 2024)
13cb19e  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 4, 2024)
b5feaa6  updated flake-8 (Mar 4, 2024)
4d5b65e  fixed formatting (Mar 4, 2024)
9ad3061  fixed formatting (Mar 4, 2024)
7b015e6  fixed formatting (Mar 4, 2024)
b595d39  fixed formatting (Mar 4, 2024)
c976954  fixed formatting (Mar 4, 2024)
49aa9a4  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 4, 2024)
c626a78  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 4, 2024)
dc32e68  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 5, 2024)
7cb2757  fixed inference bugs (Mar 5, 2024)
452227e  apply ruff (Mar 5, 2024)
3776a66  apply yapf (Mar 5, 2024)
689c3ec  bug fixes (Mar 5, 2024)
697969d  ruff and yapf (Mar 5, 2024)
1d43043  ruff and yapf (Mar 5, 2024)
6e4b06e  fixed bug in config.scale_qk_dot_by_d (Mar 5, 2024)
4321fc4  updated architectures in config (Mar 5, 2024)
a6166d1  apply ruff (Mar 5, 2024)
b68e2b1  apply ruff (Mar 5, 2024)
51b745a  apply ruff (Mar 5, 2024)
4fb9fe9  apply yapf (Mar 5, 2024)
919ad17  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 6, 2024)
65669fc  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 6, 2024)
c9a3db5  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 6, 2024)
e04e56d  fixed bug in multi GPU setting (Mar 6, 2024)
23d9b76  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 7, 2024)
98759e3  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 9, 2024)
8fd0aec  adapted to PR #3005 (Mar 9, 2024)
a80e2dc  apply yapf (Mar 9, 2024)
8a852cf  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 9, 2024)
eb0a8e7  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 10, 2024)
d60309e  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 11, 2024)
6b6cb72  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 11, 2024)
ade4c0a  apply ruff (Mar 11, 2024)
b4012a2  applied ruff (Mar 11, 2024)
07d4e5d  applied yapf (Mar 11, 2024)
159f7f9  applied yapf (Mar 11, 2024)
7fd25e2  Merge branch 'main' into main (grandiose-pizza, Mar 11, 2024)
4965c56  adapted to #3299 (Mar 11, 2024)
b68e975  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 11, 2024)
9860285  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 14, 2024)
7ee7e5a  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 18, 2024)
940e409  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 20, 2024)
85cc0ce  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 21, 2024)
33a3a8c  adapted to #3233 and bug fix for gpt2 (Mar 21, 2024)
d0b4df5  applied ruff and yapf (Mar 21, 2024)
31c12c8  apply ruff (Mar 21, 2024)
54d17c7  format (Mar 21, 2024)
36e9186  Merge branch 'vllm-project:main' into main (grandiose-pizza, Mar 21, 2024)
1 change: 1 addition & 0 deletions README.md
@@ -76,6 +76,7 @@ vLLM seamlessly supports many Hugging Face models, including the following architectures:
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
 - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
 - InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
+- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
 - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
 - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
 - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
6 changes: 5 additions & 1 deletion docs/source/models/supported_models.rst
@@ -66,7 +66,11 @@ Alongside each architecture, we include some popular models that use it.
    * - :code:`InternLM2ForCausalLM`
      - InternLM2
      - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
-     -
+     -
+   * - :code:`JAISLMHeadModel`
+     - Jais
+     - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
+     -
    * - :code:`LlamaForCausalLM`
      - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi
      - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
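For reference, a minimal sketch of running one of the Jais checkpoints listed above through vLLM's offline Python API. The prompt and sampling values are illustrative; depending on your transformers version, trust_remote_code may not be strictly required once vLLM's own JAISConfig support from this PR is installed:

    from vllm import LLM, SamplingParams

    # Jais ships custom config/modeling code on the Hugging Face Hub.
    llm = LLM(model="core42/jais-13b-chat", trust_remote_code=True)
    sampling = SamplingParams(temperature=0.7, max_tokens=64)
    outputs = llm.generate(["What is the capital of the UAE?"], sampling)
    print(outputs[0].outputs[0].text)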
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -27,6 +27,7 @@
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
     "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
+    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
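Each entry in this registry maps the architecture string from a checkpoint's config.json to a (module, class name) pair that vLLM imports lazily. A simplified sketch of that lookup (not the actual loader code):

    import importlib

    _MODELS = {
        "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    }

    def resolve_model_cls(architecture: str):
        # Import vllm/model_executor/models/<module>.py on first use.
        module_name, cls_name = _MODELS[architecture]
        module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}")
        return getattr(module, cls_name)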
3 changes: 1 addition & 2 deletions vllm/model_executor/models/gpt2.py
@@ -242,8 +242,7 @@ def sample(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, logits,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,
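This gpt2.py fix (see the "adapted to #3233 and bug fix for gpt2" commit above) aligns the model with the refactor that moved logits computation out of Sampler into a separate compute_logits() step, so Sampler no longer receives the lm_head weight. Roughly, the per-step flow is now (a sketch; JAISLMHeadModel below follows the same pattern):

    hidden_states = model(input_ids, positions, kv_caches, input_metadata)
    logits = model.compute_logits(hidden_states, sampling_metadata)
    next_tokens = model.sample(logits, sampling_metadata)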
354 changes: 354 additions & 0 deletions vllm/model_executor/models/jais.py
@@ -0,0 +1,354 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""

import math
from typing import List, Optional, Tuple

import torch
from torch import nn

from vllm.transformers_utils.configs import JAISConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearMethodBase,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, )
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_rank,
)
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata

KVCache = Tuple[torch.Tensor, torch.Tensor]


class SwiGLUActivation(nn.Module):

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return x1 * nn.functional.silu(x2)


def _get_alibi_slopes(n):

    def get_slopes_power_of_2(n):
        start = 2**(-(2**-(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    else:
        closest_power_of_2 = 2**math.floor(math.log2(n))
        return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
            2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
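# For a power-of-two head count n, the closed form above yields a geometric
# sequence starting at 2**(-8/n); e.g. _get_alibi_slopes(8) returns
# [1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128, 1/256], the schedule from the
# ALiBi paper (Press et al., "Train Short, Test Long").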


class JAISAttention(nn.Module):

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
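        # muP: if the config sets (mup_)scale_qk_dot_by_d, attention logits
        # are scaled by 1/head_dim instead of the usual 1/sqrt(head_dim).
        # Some Jais configs name the flag scale_qk_dot_by_d, so normalize it.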
        if hasattr(config, "scale_qk_dot_by_d"):
            config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
        self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )

        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end]
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            alibi_slopes=alibi_slopes,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class JAISMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            linear_method=linear_method,
        )
        self.c_fc2 = (ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            linear_method=linear_method,
        ) if self.swiglu else None)
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            linear_method=linear_method,
        )

        # Released Jais checkpoints set activation_function == "swiglu";
        # SwiGLUActivation combines the c_fc and c_fc2 projections, and the
        # single-argument fallback in forward() is not exercised by them.
        self.act = SwiGLUActivation()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.swiglu:
            hidden_states2, _ = self.c_fc2(hidden_states)
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = (self.act(hidden_states, hidden_states2)
                         if self.swiglu else self.act(hidden_states))
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class JAISBlock(nn.Module):

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = JAISAttention(config, linear_method)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = JAISMLP(inner_dim, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class JAISModel(nn.Module):

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
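        # With ALiBi, positional information comes from per-head attention
        # biases, so no learned position embeddings (wpe) are created.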
        self.wpe = (nn.Embedding(config.max_position_embeddings,
                                 self.embed_dim)
                    if config.position_embedding_type != "alibi" else None)
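        # muP: token embeddings are multiplied by a constant scale; the
        # config key name differs across Jais releases.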
if hasattr(config, "embeddings_scale"):
self.embeddings_scale = config.embeddings_scale
else:
self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([
JAISBlock(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
if self.wpe is not None:
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = inputs_embeds
hidden_states *= torch.tensor(float(self.embeddings_scale),
dtype=hidden_states.dtype)

for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)

hidden_states = self.ln_f(hidden_states)
return hidden_states


class JAISLMHeadModel(nn.Module):

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = JAISModel(config, linear_method)
        self.lm_head_weight = self.transformer.wte.weight
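        # muP: logits are scaled to keep output magnitudes stable across
        # model widths; configs store either width_scale directly or the
        # pair mup_output_alpha * mup_width_scale.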
        if hasattr(config, "width_scale"):
            self.output_logits_scale = config.width_scale
        else:
            self.output_logits_scale = (config.mup_output_alpha *
                                        config.mup_width_scale)
        self.logits_processor = LogitsProcessor(
            vocab_size=config.vocab_size, scale=self.output_logits_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if "relative_pe" in name:
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
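Given the multi-GPU fix in the commit list ("fixed bug in multi GPU setting"), the larger checkpoints can be sharded with tensor parallelism, which exercises the get_tensor_model_parallel_* logic above. A sketch, assuming two visible GPUs (values illustrative):

    from vllm import LLM

    llm = LLM(model="core42/jais-30b-chat-v3",
              tensor_parallel_size=2,
              trust_remote_code=True)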