[Core][Feature] Input metadata dump on crash (#13407)

Signed-off-by: Wallas Santos <wallashss@ibm.com>
Author: Wallas Henrique
Date: 2025-05-07 19:15:09 -03:00 (committed by GitHub)
parent ed5272cf21
commit d43f914d42
5 changed files with 169 additions and 9 deletions

View File

@@ -75,7 +75,7 @@ body:
```
```
The error message you got, with the full traceback.
The error message you got, with the full traceback, and any error logs tagged [dump_input.py:##] if present.
```
validations:
required: true

View File

@@ -5,11 +5,13 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import weakref
from unittest.mock import Mock
import pytest
from vllm import LLM
from vllm.platforms import current_platform
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
from ..conftest import VllmRunner
from ..models.utils import check_outputs_equal
@@ -152,9 +154,44 @@ def test_models_distributed(
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
from vllm.envs import VLLM_USE_V1
if not VLLM_USE_V1:
pytest.skip("Skipping V0: input dump on crash is only supported on V1")
# Run the engine core in-process so the mock below reaches the real executor
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
if isinstance(vllm_model.model.llm_engine, LLMEngineV1):
v1_test_failed_model_execution(vllm_model)
def v1_test_failed_model_execution(vllm_model):
engine = vllm_model.model.llm_engine
mocked_execute_model = Mock(
side_effect=RuntimeError("Mocked Critical Error"))
engine.engine_core.engine_core.model_executor.execute_model = \
    mocked_execute_model
with pytest.raises(RuntimeError) as exc_info:
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
assert isinstance(exc_info.value, RuntimeError)
assert "Mocked Critical Error" in str(exc_info.value)

View File

@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
import contextlib
import enum
import json
from typing import Optional
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.metrics.stats import SchedulerStats
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
def prepare_object_to_dump(obj) -> str:
if isinstance(obj, str):
return "'{obj}'" # Double quotes
elif isinstance(obj, dict):
dict_str = ', '.join(f'{str(k)}: {prepare_object_to_dump(v)}'
                     for k, v in obj.items())
return f'{{{dict_str}}}'
elif isinstance(obj, list):
return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
elif isinstance(obj, set):
return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
elif isinstance(obj, tuple):
return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
elif isinstance(obj, enum.Enum):
return repr(obj)
elif isinstance(obj, torch.Tensor):
# Print only the tensor's metadata, not its contents, so no sensitive data
# is exposed and some information survives even if the CUDA runtime crashed
return (f"Tensor(shape={obj.shape}, "
f"device={obj.device}, "
f"dtype={obj.dtype})")
elif hasattr(obj, 'anon_repr'):
return obj.anon_repr()
elif hasattr(obj, '__dict__'):
items = obj.__dict__.items()
dict_str = ','.join(f'{str(k)}={prepare_object_to_dump(v)}'
                    for k, v in items)
return (f"{type(obj).__name__}({dict_str})")
else:
# Fall back to JSON serialization; use repr() for anything JSON cannot handle
try:
return json.dumps(obj)
except (TypeError, OverflowError):
return repr(obj)
def dump_engine_exception(config: VllmConfig,
scheduler_output: SchedulerOutput,
scheduler_stats: Optional[SchedulerStats]):
# NOTE: make sure that logging the extra info can never raise and
# mask the original exception
with contextlib.suppress(BaseException):
_dump_engine_exception(config, scheduler_output, scheduler_stats)
def _dump_engine_exception(config: VllmConfig,
scheduler_output: SchedulerOutput,
scheduler_stats: Optional[SchedulerStats]):
logger.error("Dumping input data")
logger.error(
"V1 LLM engine (v%s) with config: %s, ",
VLLM_VERSION,
config,
)
try:
dump_obj = prepare_object_to_dump(scheduler_output)
logger.error("Dumping scheduler output for model execution:")
logger.error(dump_obj)
if scheduler_stats:
logger.error(scheduler_stats)
except BaseException as exception:
logger.error("Error preparing object to dump")
logger.error(repr(exception))
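To make the dump format concrete, here is a small illustrative snippet. It assumes only what is shown above (the `prepare_object_to_dump` helper and the `vllm.logging_utils.dump_input` module path used by the engine-core import further down); the exact tensor metadata depends on the local torch defaults:

```python
import torch

from vllm.logging_utils.dump_input import prepare_object_to_dump

payload = {
    "req_id": "req-0",            # strings are wrapped in quotes
    "token_ids": [1, 2, 3],       # list elements are dumped one by one
    "logits": torch.zeros(2, 4),  # tensors are reduced to shape/device/dtype
}
print(prepare_object_to_dump(payload))
# Roughly: {req_id: 'req-0', token_ids: [1, 2, 3],
#          logits: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32)}
```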

View File

@@ -48,6 +48,33 @@ class NewRequestData:
lora_request=request.lora_request,
)
def __repr__(self):
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
f"mm_inputs={self.mm_inputs},"
f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
")")
# Version of __repr__ that hides the prompt contents: only the number of
# prompt tokens is logged, not the token IDs themselves
def anon_repr(self):
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
f"mm_inputs={self.mm_inputs},"
f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
")")
@dataclass
class CachedRequestData:
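The `anon_repr` hook is what `prepare_object_to_dump` looks for before falling back to `__dict__`, so any class can opt into a redacted dump. A small sketch using a hypothetical stand-in class:

```python
from vllm.logging_utils.dump_input import prepare_object_to_dump


class FakeRequest:
    """Hypothetical stand-in for NewRequestData, for illustration only."""

    def __init__(self, prompt_token_ids):
        self.prompt_token_ids = prompt_token_ids

    def anon_repr(self) -> str:
        # Expose only the length, never the raw prompt token IDs.
        return f"FakeRequest(prompt_token_ids_len={len(self.prompt_token_ids)})"


print(prepare_object_to_dump(FakeRequest([1, 2, 3])))
# -> FakeRequest(prompt_token_ids_len=3)
```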

View File

@@ -19,6 +19,7 @@ from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
@@ -56,6 +57,7 @@ class EngineCore:
executor_fail_callback: Optional[Callable] = None):
assert vllm_config.model_config.runner_type != "pooling"
self.vllm_config = vllm_config
logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)
@@ -191,6 +193,16 @@ class EngineCore:
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)
def execute_model(self, scheduler_output: SchedulerOutput):
try:
return self.model_executor.execute_model(scheduler_output)
except BaseException as err:
# NOTE: dump_engine_exception() is best-effort and never raises
dump_engine_exception(self.vllm_config, scheduler_output,
self.scheduler.make_stats())
# Re-raise exception
raise err
def step(self) -> EngineCoreOutputs:
"""Schedule, execute, and make output."""
@@ -202,9 +214,9 @@ class EngineCore:
scheduler_stats=self.scheduler.make_stats(),
)
scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
model_output = self.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output) # type: ignore
scheduler_output, model_output) # type: ignore
return engine_core_outputs
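The control flow added to `EngineCore.execute_model` reduces to a small reusable pattern: attempt the work, dump diagnostics on a best-effort basis, and re-raise so the original failure still propagates. A generic sketch of that pattern (not vLLM code):

```python
import contextlib
import logging

logger = logging.getLogger(__name__)


def run_with_crash_dump(fn, *args, **kwargs):
    try:
        return fn(*args, **kwargs)
    except BaseException:
        # Best-effort diagnostics: anything raised while dumping is suppressed
        # so it can never mask the original exception.
        with contextlib.suppress(BaseException):
            logger.error("Execution failed; dumping call metadata: %r", args)
        # Re-raise the original exception unchanged.
        raise
```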