[Misc] Don't dump contents of kvcache tensors on errors (#8527)

This commit is contained in:
Nick Hill 2024-09-17 20:24:29 +01:00 committed by GitHub
parent a54ed80249
commit 56c3de018c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,11 +3,13 @@ import pickle
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
TypeVar) Optional, Type, TypeVar)
import torch import torch
from torch import is_tensor
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
@ -17,6 +19,8 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
logger = init_logger(__name__)
T = TypeVar('T', bound="BroadcastableModelInput") T = TypeVar('T', bound="BroadcastableModelInput")
@ -113,6 +117,8 @@ def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
except Exception as err: except Exception as err:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl" filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"
logger.info("Writing input of failed execution to %s...",
filename)
with open(filename, "wb") as filep: with open(filename, "wb") as filep:
dumped_inputs = { dumped_inputs = {
k: v k: v
@ -122,7 +128,19 @@ def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
for i, arg in enumerate(args): for i, arg in enumerate(args):
if i not in (exclude_args or []): if i not in (exclude_args or []):
dumped_inputs[f"arg_{i}"] = arg dumped_inputs[f"arg_{i}"] = arg
# Only persist dtype and shape for kvcache tensors
# (can be way to big otherwise)
if (kv_caches := dumped_inputs.get("kv_caches")) \
and isinstance(kv_caches, Iterable):
dumped_inputs["kv_caches"] = [(t.dtype, t.shape)
for t in kv_caches
if is_tensor(t)]
pickle.dump(dumped_inputs, filep) pickle.dump(dumped_inputs, filep)
logger.info(
"Completed writing input of failed execution to %s.",
filename)
raise type(err)( raise type(err)(
f"Error in model execution (input dumped to {filename}): " f"Error in model execution (input dumped to {filename}): "
f"{str(err)}") from err f"{str(err)}") from err