[MISC] Remove model input dumping when exception (#12582)

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Cody Yu 2025-02-03 13:34:16 -08:00 committed by GitHub
parent 4797dad3ec
commit cf58b9c4ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 3 additions and 128 deletions

View File

@ -30,15 +30,6 @@ body:
</details>
validations:
required: true
- type: textarea
attributes:
label: Model Input Dumps
description: |
If you are facing crashing due to illegal memory access or other issues with model execution, vLLM may dump the problematic input of the model. In this case, you will see the message `Error in model execution (input dumped to /tmp/err_xxx.pkl)`. If you see this message, please zip the file (because GitHub doesn't support .pkl file format) and upload it here. This will help us to reproduce the issue and facilitate the debugging process.
placeholder: |
Upload the dumped input file.
validations:
required: false
- type: textarea
attributes:
label: 🐛 Describe the bug

View File

@ -4,16 +4,12 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import pickle
import re
import weakref
from unittest.mock import patch
import pytest
from vllm import LLM
from vllm.platforms import current_platform
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from ..conftest import VllmRunner
from ..models.utils import check_outputs_equal
@ -151,57 +147,3 @@ def test_models_distributed(
name_0="hf",
name_1="vllm",
)
@pytest.mark.skip_v1
def test_model_with_failure(vllm_runner) -> None:
try:
with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
side_effect=ValueError()):
with pytest.raises(ValueError) as exc_info:
vllm_runner("facebook/opt-125m",
dtype="half",
enforce_eager=False,
gpu_memory_utilization=0.7)
matches = re.search(r"input dumped to (.+).pkl",
str(exc_info.value))
assert matches is not None
filename = f"{matches.group(1)}.pkl"
with open(filename, "rb") as filep:
inputs = pickle.load(filep)
if any(key not in inputs for key in ("arg_1", "arg_2", "arg_3")):
raise AssertionError("Missing keys in dumped inputs. Dumped keys: "
f"{list(inputs.keys())}")
assert isinstance(inputs["arg_1"],
ModelInputForGPUWithSamplingMetadata)
finally:
os.remove(filename)
@pytest.mark.skip_v1
def test_failure_with_async_out_proc(vllm_runner) -> None:
filename = None
try:
with vllm_runner("facebook/opt-125m",
dtype="half",
enforce_eager=False,
gpu_memory_utilization=0.7) as vllm_model,\
patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
side_effect=ValueError()):
model_config = vllm_model.model.llm_engine.model_config
assert model_config.use_async_output_proc
with pytest.raises(ValueError) as exc_info:
vllm_model.generate_greedy('how to make pizza?', 250)
matches = re.search(r"input dumped to (.+).pkl",
str(exc_info.value))
assert matches is not None
filename = f"{matches.group(1)}.pkl"
finally:
# Clean up
if filename is not None:
os.remove(filename)
pass

View File

@ -57,7 +57,7 @@ from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict, dump_input_when_exception)
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@ -1647,7 +1647,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
virtual_engine=virtual_engine)
@torch.inference_mode()
@dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,

View File

@ -1,16 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
import dataclasses
import pickle
from abc import ABC, abstractmethod
from datetime import datetime
from functools import wraps
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Optional, Type, TypeVar)
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
import torch
import torch.nn as nn
from torch import is_tensor
from vllm.config import VllmConfig
from vllm.logger import init_logger
@ -107,59 +103,6 @@ def _init_frozen_model_input_from_tensor_dict(
return tensor_dict
def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
exclude_kwargs: Optional[List[str]] = None):
def _inner(func):
@wraps(func)
def _wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as err:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"
logger.info("Writing input of failed execution to %s...",
filename)
with open(filename, "wb") as filep:
dumped_inputs = {
k: v
for k, v in kwargs.items()
if k not in (exclude_kwargs or [])
}
for i, arg in enumerate(args):
if i not in (exclude_args or []):
dumped_inputs[f"arg_{i}"] = arg
# Only persist dtype and shape for kvcache tensors
# (can be way to big otherwise)
if (kv_caches := dumped_inputs.get("kv_caches")) \
and isinstance(kv_caches, Iterable):
dumped_inputs["kv_caches"] = [(t.dtype, t.shape)
for t in kv_caches
if is_tensor(t)]
try:
pickle.dump(dumped_inputs, filep)
except Exception as pickle_err:
logger.warning(
"Failed to pickle inputs of failed execution: %s",
str(pickle_err))
raise type(err)(f"Error in model execution: "
f"{str(err)}") from err
logger.info(
"Completed writing input of failed execution to %s.",
filename)
raise type(err)(
f"Error in model execution (input dumped to {filename}): "
f"{str(err)}") from err
return _wrapper
return _inner
class BroadcastableModelInput(ABC):
@abstractmethod