From ca794e0969957c2f9320b7b83801e6818f680bf7 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 17 Sep 2024 20:24:29 +0100 Subject: [PATCH] [Misc] Don't dump contents of kvcache tensors on errors (#8527) Signed-off-by: Amit Garg --- vllm/worker/model_runner_base.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 94d2507968382..975b88c0e79a2 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -3,11 +3,13 @@ from abc import ABC, abstractmethod from datetime import datetime from functools import wraps -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, - TypeVar) +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, + Optional, Type, TypeVar) import torch +from torch import is_tensor +from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SequenceGroupMetadata @@ -17,6 +19,8 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata +logger = init_logger(__name__) + T = TypeVar('T', bound="BroadcastableModelInput") @@ -113,6 +117,8 @@ def _wrapper(*args, **kwargs): except Exception as err: timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl" + logger.info("Writing input of failed execution to %s...", + filename) with open(filename, "wb") as filep: dumped_inputs = { k: v @@ -122,7 +128,19 @@ def _wrapper(*args, **kwargs): for i, arg in enumerate(args): if i not in (exclude_args or []): dumped_inputs[f"arg_{i}"] = arg + + # Only persist dtype and shape for kvcache tensors + # (can be way to big otherwise) + if (kv_caches := dumped_inputs.get("kv_caches")) \ + and isinstance(kv_caches, Iterable): + dumped_inputs["kv_caches"] = [(t.dtype, t.shape) + for t in kv_caches + if is_tensor(t)] + pickle.dump(dumped_inputs, filep) + logger.info( + "Completed writing input of failed execution to %s.", + filename) raise type(err)( f"Error in model execution (input dumped to {filename}): " f"{str(err)}") from err