[Misc] Don't dump contents of kvcache tensors on errors (#8527)

This commit is contained in:
Nick Hill 2024-09-17 20:24:29 +01:00 committed by GitHub
parent a54ed80249
commit 56c3de018c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,11 +3,13 @@ import pickle
from abc import ABC, abstractmethod
from datetime import datetime
from functools import wraps
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Optional, Type, TypeVar)
import torch
from torch import is_tensor
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
@ -17,6 +19,8 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
logger = init_logger(__name__)
T = TypeVar('T', bound="BroadcastableModelInput")
@ -113,6 +117,8 @@ def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
except Exception as err:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"
logger.info("Writing input of failed execution to %s...",
filename)
with open(filename, "wb") as filep:
dumped_inputs = {
k: v
@ -122,7 +128,19 @@ def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
for i, arg in enumerate(args):
if i not in (exclude_args or []):
dumped_inputs[f"arg_{i}"] = arg
# Only persist dtype and shape for kvcache tensors
# (can be way too big otherwise)
if (kv_caches := dumped_inputs.get("kv_caches")) \
and isinstance(kv_caches, Iterable):
dumped_inputs["kv_caches"] = [(t.dtype, t.shape)
for t in kv_caches
if is_tensor(t)]
pickle.dump(dumped_inputs, filep)
logger.info(
"Completed writing input of failed execution to %s.",
filename)
raise type(err)(
f"Error in model execution (input dumped to {filename}): "
f"{str(err)}") from err