From d43f914d42dc00a59ca8b6d26363cf02b3b898b2 Mon Sep 17 00:00:00 2001
From: Wallas Henrique
Date: Wed, 7 May 2025 19:15:09 -0300
Subject: [PATCH] [Core][Feature] Input metadata dump on crash (#13407)

Signed-off-by: Wallas Santos
---
 .github/ISSUE_TEMPLATE/400-bug-report.yml |  2 +-
 .../test_basic_correctness.py             | 49 +++++++++--
 vllm/logging_utils/dump_input.py          | 84 +++++++++++++++++++
 vllm/v1/core/sched/output.py              | 27 ++++++
 vllm/v1/engine/core.py                    | 16 +++-
 5 files changed, 169 insertions(+), 9 deletions(-)
 create mode 100644 vllm/logging_utils/dump_input.py

diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml
index 637d2dd114548..00b0f024c0da5 100644
--- a/.github/ISSUE_TEMPLATE/400-bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml
@@ -75,7 +75,7 @@ body:
       ```

       ```
-      The error message you got, with the full traceback.
+      The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present.
       ```
   validations:
     required: true
diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 1458f0893a93c..9f3b0e8ae079b 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -5,11 +5,13 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`.
 """
 import os
 import weakref
+from unittest.mock import Mock

 import pytest

 from vllm import LLM
 from vllm.platforms import current_platform
+from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1

 from ..conftest import VllmRunner
 from ..models.utils import check_outputs_equal
@@ -152,9 +154,44 @@ def test_models_distributed(
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

-    check_outputs_equal(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
+        check_outputs_equal(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
+
+    from vllm.envs import VLLM_USE_V1
+
+    if not VLLM_USE_V1:
+        pytest.skip("Skipping V0 test, dump input not supported")
+
+    # Needed to mock an error in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
+
+    with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
+        if isinstance(vllm_model.model.llm_engine, LLMEngineV1):
+            v1_test_failed_model_execution(vllm_model)
+
+
+def v1_test_failed_model_execution(vllm_model):
+
+    engine = vllm_model.model.llm_engine
+    mocked_execute_model = Mock(
+        side_effect=RuntimeError("Mocked Critical Error"))
+    engine.engine_core.engine_core.model_executor.execute_model = \
+        mocked_execute_model
+
+    with pytest.raises(RuntimeError) as exc_info:
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
+    assert isinstance(exc_info.value, RuntimeError)
+    assert "Mocked Critical Error" in str(exc_info.value)
diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py
new file mode 100644
index 0000000000000..169e247940953
--- /dev/null
+++ b/vllm/logging_utils/dump_input.py
@@ -0,0 +1,84 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import contextlib
+import enum
+import json
+from typing import Optional
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.metrics.stats import SchedulerStats
+from vllm.version import __version__ as VLLM_VERSION
+
+logger = init_logger(__name__)
+
+
+def prepare_object_to_dump(obj) -> str:
+    if isinstance(obj, str):
+        return f"'{obj}'"  # Quote strings so they stand out in the dump
+    elif isinstance(obj, dict):
+        dict_str = ', '.join(f'{str(k)}: {prepare_object_to_dump(v)}'
+                             for k, v in obj.items())
+        return f'{{{dict_str}}}'
+    elif isinstance(obj, list):
+        return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
+    elif isinstance(obj, set):
+        return f"[{', '.join([prepare_object_to_dump(v) for v in list(obj)])}]"
+    # Tuples and sets are rendered with list syntax for simplicity
+    elif isinstance(obj, tuple):
+        return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
+    elif isinstance(obj, enum.Enum):
+        return repr(obj)
+    elif isinstance(obj, torch.Tensor):
+        # Print only the tensor's metadata, never its data, to avoid leaking
+        # anything sensitive and to keep some context if the CUDA runtime crashed
+        return (f"Tensor(shape={obj.shape}, "
+                f"device={obj.device}, "
+                f"dtype={obj.dtype})")
+    elif hasattr(obj, 'anon_repr'):
+        return obj.anon_repr()
+    elif hasattr(obj, '__dict__'):
+        items = obj.__dict__.items()
+        dict_str = ','.join(f'{str(k)}={prepare_object_to_dump(v)}'
+                            for k, v in items)
+        return f"{type(obj).__name__}({dict_str})"
+    else:
+        # Fall back to JSON serialization; use repr() if that fails
+        try:
+            return json.dumps(obj)
+        except (TypeError, OverflowError):
+            return repr(obj)
+
+
+def dump_engine_exception(config: VllmConfig,
+                          scheduler_output: SchedulerOutput,
+                          scheduler_stats: Optional[SchedulerStats]):
+    # NOTE: make sure that logging the extra info can never raise and
+    # mask the original error
+    with contextlib.suppress(BaseException):
+        _dump_engine_exception(config, scheduler_output, scheduler_stats)
+
+
+def _dump_engine_exception(config: VllmConfig,
+                           scheduler_output: SchedulerOutput,
+                           scheduler_stats: Optional[SchedulerStats]):
+    logger.error("Dumping input data")
+
+    logger.error(
+        "V1 LLM engine (v%s) with config: %s",
+        VLLM_VERSION,
+        config,
+    )
+
+    try:
+        dump_obj = prepare_object_to_dump(scheduler_output)
+        logger.error("Dumping scheduler output for model execution:")
+        logger.error(dump_obj)
+        if scheduler_stats:
+            logger.error(scheduler_stats)
+    except BaseException as exception:
+        logger.error("Error preparing object to dump")
+        logger.error(repr(exception))
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 928fb231a1f2d..24032498e50ba 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -48,6 +48,33 @@ class NewRequestData:
             lora_request=request.lora_request,
         )

+    def __repr__(self):
+        return (f"NewRequestData("
+                f"req_id={self.req_id},"
+                f"prompt_token_ids={self.prompt_token_ids},"
+                f"mm_inputs={self.mm_inputs},"
+                f"mm_hashes={self.mm_hashes},"
+                f"mm_positions={self.mm_positions},"
+                f"sampling_params={self.sampling_params},"
+                f"block_ids={self.block_ids},"
+                f"num_computed_tokens={self.num_computed_tokens},"
+                f"lora_request={self.lora_request}"
+                ")")
+
+    # Version of __repr__ with the prompt data obfuscated
+    def anon_repr(self):
+        return (f"NewRequestData("
+                f"req_id={self.req_id},"
+                f"prompt_token_ids_len={len(self.prompt_token_ids)},"
+                f"mm_inputs={self.mm_inputs},"
+                f"mm_hashes={self.mm_hashes},"
+                f"mm_positions={self.mm_positions},"
+                f"sampling_params={self.sampling_params},"
+                f"block_ids={self.block_ids},"
+                f"num_computed_tokens={self.num_computed_tokens},"
+                f"lora_request={self.lora_request}"
+                ")")
+

 @dataclass
 class CachedRequestData:
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index e772615b7861a..d9dd4957cff2f 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -19,6 +19,7 @@ from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import stateless_destroy_torch_distributed_process_group
 from vllm.executor.multiproc_worker_utils import _add_prefix
 from vllm.logger import init_logger
+from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
@@ -56,6 +57,7 @@ class EngineCore:
                  executor_fail_callback: Optional[Callable] = None):
         assert vllm_config.model_config.runner_type != "pooling"

+        self.vllm_config = vllm_config
         logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                     VLLM_VERSION, vllm_config)

@@ -191,6 +193,16 @@ class EngineCore:
         self.scheduler.finish_requests(request_ids,
                                        RequestStatus.FINISHED_ABORTED)

+    def execute_model(self, scheduler_output: SchedulerOutput):
+        try:
+            return self.model_executor.execute_model(scheduler_output)
+        except BaseException as err:
+            # NOTE: dump_engine_exception never raises, so it cannot mask err
+            dump_engine_exception(self.vllm_config, scheduler_output,
+                                  self.scheduler.make_stats())
+            # Re-raise exception
+            raise err
+
     def step(self) -> EngineCoreOutputs:
         """Schedule, execute, and make output."""

@@ -202,9 +214,9 @@
             scheduler_stats=self.scheduler.make_stats(),
         )
         scheduler_output = self.scheduler.schedule()
-        output = self.model_executor.execute_model(scheduler_output)
+        model_output = self.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, output)  # type: ignore
+            scheduler_output, model_output)  # type: ignore

         return engine_core_outputs
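Reviewer note: the sketch below shows, without crashing a real engine, roughly what prepare_object_to_dump renders once the dump fires. It is a minimal, hypothetical example: FakeRequest and Phase are stand-ins defined here purely for illustration (they are not vLLM types), the expected output in the trailing comment is approximate, and only prepare_object_to_dump itself comes from this patch, so the snippet assumes a vLLM build with the patch applied.

# Minimal sketch (not part of the patch): render a small hand-built payload
# with prepare_object_to_dump. FakeRequest and Phase are hypothetical stand-ins.
import enum
from dataclasses import dataclass, field

import torch

from vllm.logging_utils.dump_input import prepare_object_to_dump


class Phase(enum.Enum):
    PREFILL = 1
    DECODE = 2


@dataclass
class FakeRequest:
    req_id: str
    phase: Phase
    block_ids: list[int] = field(default_factory=list)


payload = {
    "request": FakeRequest("req-0", Phase.PREFILL, block_ids=[0, 1, 2]),
    "logits": torch.zeros(4, 8),  # tensors: only shape/device/dtype are printed
    "prompt": "The capital of France is",  # plain strings are printed verbatim
}
print(prepare_object_to_dump(payload))
# Roughly:
# {request: FakeRequest(req_id='req-0',phase=<Phase.PREFILL: 1>,block_ids=[0, 1, 2]),
#  logits: Tensor(shape=torch.Size([4, 8]), device=cpu, dtype=torch.float32),
#  prompt: 'The capital of France is'}

In a real crash the same rendering is applied to the whole SchedulerOutput by EngineCore.execute_model via dump_engine_exception, and NewRequestData is rendered through anon_repr, so prompt token ids are reported only by their length.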