Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] Adding torch compile annotations to some models #9639

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ Text Generation
- ✅︎
* - :code:`JAISLMHeadModel`
- Jais
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
- :code:`inceptionai/jais-13b`, :code:`inceptionai/jais-13b-chat`, :code:`inceptionai/jais-30b-v3`, :code:`inceptionai/jais-30b-chat-v3`, etc.
-
- ✅︎
* - :code:`JambaForCausalLM`
Expand Down
2 changes: 1 addition & 1 deletion tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def iter_params(self, model_name: str):
# Uses Llama
# "internlm/internlm-chat-7b": PPTestSettings.fast(),
"internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
"core42/jais-13b-chat": PPTestSettings.fast(),
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
# TODO: Implement PP
# "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/models/jais.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
# reserved.
Expand All @@ -26,6 +26,7 @@
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
Expand Down Expand Up @@ -212,6 +213,7 @@ def forward(
return hidden_states


@support_torch_compile
class JAISModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
Expand Down Expand Up @@ -348,6 +349,7 @@ def forward(
return hidden_states, None


@support_torch_compile
class MiniCPMModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch.nn as nn

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
Expand Down Expand Up @@ -204,6 +205,7 @@ def forward(
return hidden_states


@support_torch_compile
class MPTModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
Expand Down Expand Up @@ -290,6 +291,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class NemotronModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from transformers import OlmoConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
Expand Down Expand Up @@ -221,6 +222,7 @@ def forward(
return hidden_states


@support_torch_compile
class OlmoModel(nn.Module):

def __init__(self,
Expand Down