From 8a02cd045ac661481ba2672846e09f5b57110f40 Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Thu, 24 Oct 2024 15:54:57 +0800 Subject: [PATCH] [torch.compile] Adding torch compile annotations to some models (#9639) Signed-off-by: youkaichao Co-authored-by: youkaichao --- docs/source/models/supported_models.rst | 2 +- tests/distributed/test_pipeline_parallel.py | 2 +- vllm/model_executor/models/jais.py | 4 +++- vllm/model_executor/models/minicpm.py | 2 ++ vllm/model_executor/models/mpt.py | 2 ++ vllm/model_executor/models/nemotron.py | 2 ++ vllm/model_executor/models/olmo.py | 2 ++ 7 files changed, 13 insertions(+), 3 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c92d65110f464..a5ce33e548b18 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -144,7 +144,7 @@ Text Generation - ✅︎ * - :code:`JAISLMHeadModel` - Jais - - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. + - :code:`inceptionai/jais-13b`, :code:`inceptionai/jais-13b-chat`, :code:`inceptionai/jais-30b-v3`, :code:`inceptionai/jais-30b-chat-v3`, etc. - - ✅︎ * - :code:`JambaForCausalLM` diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 8d0190e37ef13..214448bf4320e 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -145,7 +145,7 @@ def iter_params(self, model_name: str): # Uses Llama # "internlm/internlm-chat-7b": PPTestSettings.fast(), "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True), - "core42/jais-13b-chat": PPTestSettings.fast(), + "inceptionai/jais-13b-chat": PPTestSettings.fast(), # TODO: Implement PP # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(), "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(), diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index c5e5393442e30..b947f24a693b5 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -1,6 +1,6 @@ # coding=utf-8 # Adapted from -# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py +# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py # Copyright 2023 The vLLM team. # Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights # reserved. @@ -26,6 +26,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -212,6 +213,7 @@ def forward( return hidden_states +@support_torch_compile class JAISModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index decd90b682a1e..03fb036020f2f 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -29,6 +29,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -348,6 +349,7 @@ def forward( return hidden_states, None +@support_torch_compile class MiniCPMModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index e3d3937b13fa0..ee802030a5ef3 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -7,6 +7,7 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -204,6 +205,7 @@ def forward( return hidden_states +@support_torch_compile class MPTModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 14515e16e34ac..72a09129fed63 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -27,6 +27,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -290,6 +291,7 @@ def forward( return hidden_states, residual +@support_torch_compile class NemotronModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 5ca7c66f5407d..90ab8abcb84b4 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -28,6 +28,7 @@ from transformers import OlmoConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul @@ -221,6 +222,7 @@ def forward( return hidden_states +@support_torch_compile class OlmoModel(nn.Module): def __init__(self,