Skip to content

Commit

Permalink
[torch.compile] Adding torch compile annotations to some models (vllm…
Browse files Browse the repository at this point in the history
…-project#9639)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
  • Loading branch information
CRZbulabula and youkaichao authored Oct 24, 2024
1 parent 9b4734e commit 846f48d
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ Text Generation
- ✅︎
* - :code:`JAISLMHeadModel`
- Jais
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
- :code:`inceptionai/jais-13b`, :code:`inceptionai/jais-13b-chat`, :code:`inceptionai/jais-30b-v3`, :code:`inceptionai/jais-30b-chat-v3`, etc.
-
- ✅︎
* - :code:`JambaForCausalLM`
Expand Down
2 changes: 1 addition & 1 deletion tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def iter_params(self, model_name: str):
# Uses Llama
# "internlm/internlm-chat-7b": PPTestSettings.fast(),
"internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
"core42/jais-13b-chat": PPTestSettings.fast(),
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
# TODO: Implement PP
# "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/models/jais.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
# reserved.
Expand All @@ -26,6 +26,7 @@
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
Expand Down Expand Up @@ -212,6 +213,7 @@ def forward(
return hidden_states


@support_torch_compile
class JAISModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
Expand Down Expand Up @@ -348,6 +349,7 @@ def forward(
return hidden_states, None


@support_torch_compile
class MiniCPMModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch.nn as nn

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
Expand Down Expand Up @@ -204,6 +205,7 @@ def forward(
return hidden_states


@support_torch_compile
class MPTModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
Expand Down Expand Up @@ -290,6 +291,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class NemotronModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from transformers import OlmoConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
Expand Down Expand Up @@ -221,6 +222,7 @@ def forward(
return hidden_states


@support_torch_compile
class OlmoModel(nn.Module):

def __init__(self,
Expand Down

0 comments on commit 846f48d

Please sign in to comment.