Skip to content

Commit

Permalink
[Misc] Don't dump contents of kvcache tensors on errors (vllm-project…
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored and Jeffwan committed Sep 19, 2024
1 parent 0ca11a1 commit 35ede69
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions vllm/worker/model_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from abc import ABC, abstractmethod
from datetime import datetime
from functools import wraps
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Optional, Type, TypeVar)

import torch
from torch import is_tensor

from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
Expand All @@ -17,6 +19,8 @@
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata

logger = init_logger(__name__)

T = TypeVar('T', bound="BroadcastableModelInput")


Expand Down Expand Up @@ -113,6 +117,8 @@ def _wrapper(*args, **kwargs):
except Exception as err:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"
logger.info("Writing input of failed execution to %s...",
filename)
with open(filename, "wb") as filep:
dumped_inputs = {
k: v
Expand All @@ -122,7 +128,19 @@ def _wrapper(*args, **kwargs):
for i, arg in enumerate(args):
if i not in (exclude_args or []):
dumped_inputs[f"arg_{i}"] = arg

# Only persist dtype and shape for kvcache tensors
# (can be way to big otherwise)
if (kv_caches := dumped_inputs.get("kv_caches")) \
and isinstance(kv_caches, Iterable):
dumped_inputs["kv_caches"] = [(t.dtype, t.shape)
for t in kv_caches
if is_tensor(t)]

pickle.dump(dumped_inputs, filep)
logger.info(
"Completed writing input of failed execution to %s.",
filename)
raise type(err)(
f"Error in model execution (input dumped to {filename}): "
f"{str(err)}") from err
Expand Down

0 comments on commit 35ede69

Please sign in to comment.