Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: natively support Granite models #2682

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
Expand Down
2 changes: 2 additions & 0 deletions nix/impure-shell.nix
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ mkShell {
[
cuda_cccl
cuda_cudart
cuda_nvrtc
cuda_nvtx
cuda_profiler_api
cudnn
libcublas
libcusolver
Expand Down
1 change: 1 addition & 0 deletions router/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ pub enum Config {
Idefics2(Idefics2),
Ssm,
GptBigcode,
Granite,
Santacoder,
Bloom,
Mpt,
Expand Down
1,355 changes: 744 additions & 611 deletions server/poetry.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions server/requirements_cuda.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
Expand All @@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
Expand Down Expand Up @@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
Expand Down
12 changes: 6 additions & 6 deletions server/requirements_intel.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
Expand All @@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
Expand Down Expand Up @@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
Expand Down
12 changes: 6 additions & 6 deletions server/requirements_rocm.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
Expand All @@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
Expand Down Expand Up @@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
Expand Down
16 changes: 14 additions & 2 deletions server/text_generation_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ class ModelType(enum.Enum):
"name": "Phi 3",
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
}
GRANITE = {
"type": "granite",
"name": "Granite",
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
}
GEMMA = {
"type": "gemma",
"name": "Gemma",
Expand Down Expand Up @@ -862,7 +867,12 @@ def get_model(
trust_remote_code=trust_remote_code,
)

elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
elif (
model_type == LLAMA
or model_type == BAICHUAN
or model_type == PHI3
or model_type == GRANITE
):
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
Expand All @@ -876,7 +886,9 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else:
return CausalLM.fallback(
model_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def __init__(
device=weights.device,
)

self.softmax_scale = self.head_size**-0.5
# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
config, "attention_multiplier", self.head_size**-0.5
)

if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
Expand All @@ -180,7 +183,7 @@ def __init__(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
bias=getattr(config, "attention_bias", False),
)

self.o_proj = TensorParallelAdapterRowLinear.load(
Expand Down Expand Up @@ -436,6 +439,11 @@ def __init__(self, index, prefix, config, weights):
eps=config.rms_norm_eps,
)

# Used in Granite
# This could eventually be baked into the weights like we do for the embeddings/lm_head
# but this would mean modifying the lora code
self.residual_multiplier = getattr(config, "residual_multiplier", None)

def forward(
self,
hidden_states,
Expand Down Expand Up @@ -466,13 +474,16 @@ def forward(
max_s,
adapter_data,
)
if self.residual_multiplier is not None:
attn_output *= self.residual_multiplier

# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)

mlp_output = self.dense(normed_attn_res_output, adapter_data)
if self.residual_multiplier is not None:
mlp_output *= self.residual_multiplier

return mlp_output, attn_res

Expand Down Expand Up @@ -624,13 +635,28 @@ def __init__(self, prefix: str, config, weights):
else:
suffix = "lm_head"

# Used in Granite
embedding_multiplier = getattr(config, "embedding_multiplier", None)
if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier

with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)

# Used in Granite
self.logits_scaling = getattr(config, "logits_scaling", None)
if self.logits_scaling is not None and self.lm_head.head is not None:
try:
# Scale the weights directly
self.lm_head.head.linear.weight.data /= self.logits_scaling
self.logits_scaled = True
except Exception:
self.logits_scaled = False

def forward(
self,
input_ids: torch.Tensor,
Expand Down Expand Up @@ -664,4 +690,11 @@ def forward(
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)

# Used in Granite
if not self.logits_scaled:
logits /= self.logits_scaling
if speculative_logits is not None:
speculative_logits /= self.logits_scaling

return logits, speculative_logits
Loading