[CI/Build] build on empty device for better dev experience #4773

Merged · 13 commits · Aug 11, 2024
Changes from 5 commits
3 changes: 1 addition & 2 deletions requirements-cpu.txt
@@ -2,5 +2,4 @@
-r requirements-common.txt

# Dependencies for x86_64 CPUs
-torch == 2.3.0+cpu
-triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
+torch == 2.3.0+cpu

Contributor Author (review comment): Dealt with triton import errors in code
4 changes: 2 additions & 2 deletions requirements-cuda.txt
@@ -6,5 +6,5 @@ ray >= 2.9
nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.3.0
-xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
-vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0
+xformers == 0.0.26.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.3.0
+vllm-flash-attn == 2.5.8.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.3.0
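The `; platform_system == 'Linux' and platform_machine == 'x86_64'` suffixes are PEP 508 environment markers: pip installs the requirement only when the marker is true on the target machine, so resolving requirements-cuda.txt on, say, macOS or an ARM host no longer fails on these two packages. A minimal sketch of how such a marker evaluates, using the third-party `packaging` library (the Darwin/arm64 environment passed in is just an illustrative override, not part of this PR):

```python
from packaging.markers import Marker

marker = Marker("platform_system == 'Linux' and platform_machine == 'x86_64'")

# Evaluate against the running interpreter's environment ...
print(marker.evaluate())  # True on x86_64 Linux, False elsewhere

# ... or against a hypothetical one (values are merged over the defaults).
print(marker.evaluate({"platform_system": "Darwin",
                       "platform_machine": "arm64"}))  # False
```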
8 changes: 7 additions & 1 deletion setup.py
@@ -36,6 +36,8 @@ def load_module_from_path(module_name, path):
assert sys.platform.startswith(
    "linux"), "vLLM only supports Linux platform (including WSL)."

+PLATFORM_AGNOSTIC_BUILD = envs.PLATFORM_AGNOSTIC_BUILD

MAIN_CUDA_VERSION = "12.1"


@@ -398,6 +400,9 @@ def _read_requirements(filename: str) -> List[str]:
ext_modules = []
package_data["vllm"].append("*.so")

+if PLATFORM_AGNOSTIC_BUILD:
+    ext_modules = []

setup(
name="vllm",
version=get_vllm_version(),
@@ -428,6 +433,7 @@ def _read_requirements(filename: str) -> List[str]:
    extras_require={
        "tensorizer": ["tensorizer==2.9.0"],
    },
-    cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
+    cmdclass={"build_ext": cmake_build_ext}
+    if not (_is_neuron() or PLATFORM_AGNOSTIC_BUILD) else {},
    package_data=package_data,
)
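Taken together with the envs.py change below, the effect is: when PLATFORM_AGNOSTIC_BUILD is set, setup.py registers no native extension modules and no CMake-backed build_ext command, so the package installs as pure Python even on a machine without a CUDA/ROCm/Neuron toolchain. A minimal standalone sketch of that gating pattern (this is not vLLM's actual setup.py; the package name and placeholder comments are illustrative):

```python
import os
from setuptools import setup

# Any non-empty value of the variable enables the platform-agnostic build.
PLATFORM_AGNOSTIC_BUILD = bool(os.environ.get("PLATFORM_AGNOSTIC_BUILD", False))

ext_modules = []   # would normally hold CMakeExtension entries
cmdclass = {}      # would normally map "build_ext" to a CMake-driven command

if not PLATFORM_AGNOSTIC_BUILD:
    # Device-specific extensions and the custom build_ext command would be
    # registered here; skipped entirely for the pure-Python build.
    pass

setup(
    name="example-package",        # hypothetical name, for illustration only
    version="0.0.1",
    ext_modules=ext_modules,
    cmdclass=cmdclass,
)
```

With that in place, something like `PLATFORM_AGNOSTIC_BUILD=1 pip install -e .` (assuming the variable is visible to the build step) should produce an editable install without compiling any kernels, which is the "empty device" dev workflow the PR title describes.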
16 changes: 14 additions & 2 deletions vllm/attention/ops/prefix_prefill.py
@@ -2,8 +2,20 @@
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch
-import triton
-import triton.language as tl

+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+try:
+    import triton
+    import triton.language as tl
+except ImportError as e:
+    logger.warning(
+        "Failed to import triton with %r. To enable vllm execution, "
+        "please install triton with `pip install triton` "
+        "(not available on macos)", e)
+    triton = type('triton', tuple(), {"__version__": "0.0.0"})()

if triton.__version__ >= "2.1.0":

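The fallback object on the last added line is what keeps this module importable: `type(name, bases, dict)` builds a throwaway class on the fly, and instantiating it yields an object whose `__version__` satisfies the version check below, so the Triton kernels are simply never defined when triton is missing. A standalone sketch of the same idea:

```python
# Stand-in object with just enough surface area for the module-level checks.
fake_triton = type("triton", tuple(), {"__version__": "0.0.0"})()

print(fake_triton.__version__)             # "0.0.0"
print(fake_triton.__version__ >= "2.1.0")  # False -> kernel definitions skipped
```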
30 changes: 28 additions & 2 deletions vllm/attention/ops/triton_flash_attention.py
@@ -21,8 +21,34 @@
"""

import torch
-import triton
-import triton.language as tl

+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+try:
+    import triton
+    import triton.language as tl
+except ImportError as e:
+    logger.warning(
+        "Failed to import triton with %r. To enable vllm execution, "
+        "please install triton with `pip install triton` "
+        "(not available on macos)", e)
+
+    def dummy_decorator(*args, **kwargs):
+        return args[0]
+
+    def dummy_callable(*args, **kwargs):
+        return None
+
+    triton = type(
+        "triton", tuple(), {
+            "jit": dummy_decorator,
+            "autotune": dummy_decorator,
+            "Config": dummy_callable,
+            "__call__": dummy_callable
+        })()
+    tl = type("tl", tuple(), {"constexpr": None})()

torch_dtype: tl.constexpr = torch.float16

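Here the stand-in needs more than a version attribute, because this module decorates kernels with `@triton.jit` and `@triton.autotune(...)` and annotates arguments with `tl.constexpr` at import time. The dummy attributes let all of that resolve without raising; the decorated names end up bound to inert placeholders rather than real kernels, which is harmless since they are never called on platforms where triton is unavailable. A minimal standalone sketch of the `jit`/`constexpr` part of the pattern:

```python
def dummy_decorator(*args, **kwargs):
    return args[0]

# Stand-ins for the real modules.
triton = type("triton", tuple(), {"jit": dummy_decorator})()
tl = type("tl", tuple(), {"constexpr": None})()

@triton.jit                              # resolves to a bound method of the stand-in
def kernel(x_ptr, BLOCK: tl.constexpr):  # the annotation is just None, also harmless
    ...

# The module imports fine; `kernel` is now an inert placeholder object
# (the stand-in itself), not something that will ever be executed here.
print(kernel)
```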
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -28,6 +28,7 @@
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_TARGET_DEVICE: str = "cuda"
+PLATFORM_AGNOSTIC_BUILD: bool = False
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
VLLM_BUILD_WITH_NEURON: bool = False
@@ -49,6 +50,10 @@
"VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"),

+# If set (to any non-empty value), build vLLM without device-specific
+# native extensions, so it can be installed on machines with no supported
+# accelerator (the "empty device" dev workflow)
+"PLATFORM_AGNOSTIC_BUILD":
+lambda: bool(os.environ.get("PLATFORM_AGNOSTIC_BUILD", False)),

# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
"MAX_JOBS":
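One behavior worth noting with this lookup: `bool()` on a string is True for any non-empty string, so the flag is effectively "set vs. unset" rather than parsed as a boolean, and `PLATFORM_AGNOSTIC_BUILD=0` would still enable it. A quick sketch of how the lambda above evaluates for a few values:

```python
import os

def platform_agnostic_build() -> bool:
    # Same expression as the lambda registered in envs.py above.
    return bool(os.environ.get("PLATFORM_AGNOSTIC_BUILD", False))

for value in (None, "", "1", "0", "false"):
    if value is None:
        os.environ.pop("PLATFORM_AGNOSTIC_BUILD", None)   # unset
    else:
        os.environ["PLATFORM_AGNOSTIC_BUILD"] = value
    print(repr(value), "->", platform_agnostic_build())

# None -> False, '' -> False, '1' -> True, '0' -> True, 'false' -> True
```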
17 changes: 15 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -5,15 +5,28 @@
from typing import Any, Dict, Optional, Tuple

import torch
-import triton
-import triton.language as tl

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)

+try:
+    import triton
+    import triton.language as tl
+except ImportError as e:
+    logger.warning(
+        "Failed to import triton with %r. To enable vllm execution, "
+        "please install triton with `pip install triton` "
+        "(not available on macos)", e)
+
+    def dummy_decorator(*args, **kwargs):
+        return args[0]
+
+    triton = type("triton", tuple(), {"jit": dummy_decorator})()
+    tl = type("tl", tuple(), {"constexpr": None, "dtype": None})()


@triton.jit
def fused_moe_kernel(
21 changes: 19 additions & 2 deletions vllm/model_executor/layers/ops/rand.py
@@ -1,8 +1,25 @@
from typing import Optional, Union

import torch
-import triton
-import triton.language as tl

+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+try:
+    import triton
+    import triton.language as tl
+except ImportError as e:
+    logger.warning(
+        "Failed to import triton with %r. To enable vllm execution, "
+        "please install triton with `pip install triton` "
+        "(not available on macos)", e)
+
+    def dummy_decorator(*args, **kwargs):
+        return args[0]
+
+    triton = type("triton", tuple(), {"jit": dummy_decorator})()
+    tl = type("tl", tuple(), {"constexpr": None})()


def seeded_uniform(
20 changes: 18 additions & 2 deletions vllm/model_executor/layers/ops/sample.py
@@ -2,11 +2,27 @@
from typing import Optional, Tuple

import torch
-import triton
-import triton.language as tl

+from vllm.logger import init_logger
from vllm.model_executor.layers.ops.rand import seeded_uniform

+logger = init_logger(__name__)
+
+try:
+    import triton
+    import triton.language as tl
+except ImportError as e:
+    logger.warning(
+        "Failed to import triton with %r. To enable vllm execution, "
+        "please install triton with `pip install triton` "
+        "(not available on macos)", e)
+
+    def dummy_decorator(*args, **kwargs):
+        return args[0]
+
+    triton = type("triton", tuple(), {"jit": dummy_decorator})()
+    tl = type("tl", tuple(), {"constexpr": None})()

_EPS = 1e-6

# This is a hardcoded limit in Triton (max block size).