Your current environment

The output of `python collect_env.py`
$ python collect_env.py
Collecting environment information...
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/utils/hub.py:127: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
warnings.warn(
PyTorch version: 2.5.0.dev20240726+rocm6.1
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.1.40091-a8dbc0c19
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 17.0.0 (https://github.com/RadeonOpenCompute/llvm-project roc-6.1.2 24193 669db884972e769450470020c06a6f132a8a065b)
CMake version: version 3.26.4
Libc version: glibc-2.31
Python version: 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.14.0-284.73.1.el9_2.x86_64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI210 (gfx90a:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.1.40093
MIOpen runtime version: 3.1.0
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 52 bits physical, 57 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 25
Model: 17
Model name: AMD EPYC 9254 24-Core Processor
Stepping: 1
CPU MHz: 4143.331
BogoMIPS: 5790.96
Virtualization: AMD-V
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 256 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d
Versions of relevant libraries:
[pip3] mypy==1.7.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] optree==0.9.1
[pip3] pytorch-triton-rocm==3.0.0+21eae954ef
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.0.dev20240726+rocm6.1
[pip3] torchvision==0.20.0.dev20240726+rocm6.1
[pip3] transformers==4.44.2
[pip3] triton==3.0.0
[conda] No relevant packages
ROCM Version: 6.1.40093-bd86f1708
Neuron SDK Version: N/A
vLLM Version: 0.5.5@760e9f71a839ddc2a05c47af1fea23eeefbc368e
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect
Model Input Dumps
No response
🐛 Describe the bug
Using multi-step scheduling on an AMD GPU crashes. The following script reproduces the issue:
```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125M",
          disable_sliding_window=True,
          num_scheduler_steps=8)
params = SamplingParams(seed=123, max_tokens=500, temperature=1)
prompts = ["How to make pizza?"]
outputs = llm.generate(prompts, sampling_params=params)
for o in outputs:
    print('_________')
    print('### Text')
    print('_________')
    for o2 in o.outputs:
        print(o2.text)
```
The stack trace:

```text
[rank0]: Traceback (most recent call last):
[rank0]:   File "/tmp/test_async_multi_step2.py", line 11, in <module>
[rank0]:     outputs = llm.generate(prompts, sampling_params=params )
[rank0]:   File "/vllm-workspace/vllm/utils.py", line 1030, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/vllm-workspace/vllm/entrypoints/llm.py", line 345, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:   File "/vllm-workspace/vllm/entrypoints/llm.py", line 686, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:   File "/vllm-workspace/vllm/engine/llm_engine.py", line 1369, in step
[rank0]:     output = self.model_executor.execute_model(
[rank0]:   File "/vllm-workspace/vllm/executor/gpu_executor.py", line 129, in execute_model
[rank0]:     output = self.driver_worker.execute_model(execute_model_req)
[rank0]:   File "/vllm-workspace/vllm/worker/worker_base.py", line 322, in execute_model
[rank0]:     output = self.model_runner.execute_model(
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/vllm-workspace/vllm/worker/multi_step_model_runner.py", line 271, in execute_model
[rank0]:     model_input = self._advance_step(
[rank0]:   File "/vllm-workspace/vllm/worker/multi_step_model_runner.py", line 362, in _advance_step
[rank0]:     attn_metadata.advance_step(num_seqs, num_queries)
[rank0]: AttributeError: 'ROCmFlashAttentionMetadata' object has no attribute 'advance_step'
```
I did a preliminary investigation and I think I found the cause. The stack trace shows that `ROCmFlashAttentionMetadata` does not implement `advance_step`. Before calling `advance_step`, `multi_step_model_runner.py` asserts that `attn_metadata` is an instance of `FlashAttentionMetadata`.
The definition of `ROCmFlashAttentionMetadata`:

```python
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
```
This class does not inherit from `FlashAttentionMetadata`, which is the class that actually implements `advance_step`, and it does not implement the method itself.
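Because the class simply lacks the method, a check on the method itself, rather than on a class name, would catch the problem before the call. The sketch below is a hypothetical illustration, not vLLM code: `AdvanceableMetadata`, `WithAdvanceStep`, `WithoutAdvanceStep`, and `checked_advance_step` are invented names. It uses a `typing.Protocol` marked `@runtime_checkable` so that `isinstance` tests for the presence of `advance_step`.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class AdvanceableMetadata(Protocol):
    # Hypothetical protocol: any object with an advance_step method matches.
    def advance_step(self, num_seqs: int, num_queries: int) -> None: ...


@dataclass
class WithAdvanceStep:
    # Stand-in for a metadata class that implements advance_step.
    steps: int = 0

    def advance_step(self, num_seqs: int, num_queries: int) -> None:
        self.steps += 1


@dataclass
class WithoutAdvanceStep:
    # Stand-in for ROCmFlashAttentionMetadata as described in this issue.
    steps: int = 0


def checked_advance_step(attn_metadata: object, num_seqs: int, num_queries: int) -> None:
    # A runtime_checkable isinstance check only verifies that the method exists
    # (not its signature), which is enough to turn a late AttributeError into
    # an explicit, early error message.
    if not isinstance(attn_metadata, AdvanceableMetadata):
        raise NotImplementedError(
            f"{type(attn_metadata).__name__} does not implement advance_step; "
            "multi-step scheduling is not supported with this attention backend"
        )
    attn_metadata.advance_step(num_seqs, num_queries)
```

With this kind of guard, passing a `WithoutAdvanceStep` instance raises a `NotImplementedError` that names the offending class instead of failing deep inside the model runner.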
WARNING:
In `multi_step_model_runner.py`, `ROCmFlashAttentionMetadata` is imported under the name `FlashAttentionMetadata` when `vllm_flash_attn` is not installed. On the ROCm backend the `isinstance` assert therefore always passes, which makes the `advance_step` call look safe even though it is not:
```python
try:
    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
    from vllm.attention.backends.rocm_flash_attn import (
        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
```
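The pitfall can be reproduced without vLLM or a GPU. The sketch below uses hypothetical stand-ins (`ROCmMetadataStandIn`, `FlashMetadataStandIn`, `run_advance_step`, and the deliberately nonexistent module `some_missing_flash_attn_module`) to mimic the fallback import above: once the ROCm class is aliased to the expected name, the `isinstance` assert cannot fail, and the problem only surfaces as an `AttributeError` at the method call.

```python
from dataclasses import dataclass


@dataclass
class ROCmMetadataStandIn:
    # Hypothetical stand-in for ROCmFlashAttentionMetadata: no advance_step.
    num_prefills: int = 0


try:
    # Hypothetical module name that does not exist, simulating a missing
    # optional dependency such as vllm_flash_attn.
    from some_missing_flash_attn_module import FlashMetadataStandIn  # type: ignore
except ModuleNotFoundError:
    # Same fallback pattern as above: alias the ROCm class under the
    # name the rest of the code expects.
    FlashMetadataStandIn = ROCmMetadataStandIn


def run_advance_step(attn_metadata: object) -> None:
    # This assert is vacuous here: the "expected" class *is* the ROCm class.
    assert isinstance(attn_metadata, FlashMetadataStandIn)
    # Fails at runtime because the stand-in has no advance_step method.
    attn_metadata.advance_step(1, 1)


if __name__ == "__main__":
    try:
        run_advance_step(ROCmMetadataStandIn())
    except AttributeError as exc:
        print(f"AttributeError: {exc}")
```

Running the script prints the same kind of `AttributeError` as the stack trace, after the assert has already passed.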