mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 22:44:29 +08:00
[Core][Feature] Input metadata dump on crash (#13407)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
parent
ed5272cf21
commit
d43f914d42
2
.github/ISSUE_TEMPLATE/400-bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/400-bug-report.yml
vendored
@ -75,7 +75,7 @@ body:
|
||||
```
|
||||
|
||||
```
|
||||
The error message you got, with the full traceback.
|
||||
The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present.
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
|
||||
@ -5,11 +5,13 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`.
|
||||
"""
|
||||
import os
|
||||
import weakref
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
|
||||
|
||||
from ..conftest import VllmRunner
|
||||
from ..models.utils import check_outputs_equal
|
||||
@ -152,9 +154,44 @@ def test_models_distributed(
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
|
||||
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
|
||||
if not VLLM_USE_V1:
|
||||
pytest.skip("Skipping V0 test, dump input not supported")
|
||||
|
||||
# Needed to mock an error in the same process
|
||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
||||
|
||||
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
|
||||
if isinstance(vllm_model.model.llm_engine, LLMEngineV1):
|
||||
v1_test_failed_model_execution(vllm_model)
|
||||
|
||||
|
||||
def v1_test_failed_model_execution(vllm_model):
|
||||
|
||||
engine = vllm_model.model.llm_engine
|
||||
mocked_execute_model = Mock(
|
||||
side_effect=RuntimeError("Mocked Critical Error"))
|
||||
engine.engine_core.engine_core.model_executor.execute_model =\
|
||||
mocked_execute_model
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
|
||||
assert isinstance(exc_info.value, RuntimeError)
|
||||
assert "Mocked Critical Error" in str(exc_info.value)
|
||||
|
||||
84
vllm/logging_utils/dump_input.py
Normal file
84
vllm/logging_utils/dump_input.py
Normal file
@ -0,0 +1,84 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import contextlib
|
||||
import enum
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def prepare_object_to_dump(obj) -> str:
|
||||
if isinstance(obj, str):
|
||||
return "'{obj}'" # Double quotes
|
||||
elif isinstance(obj, dict):
|
||||
dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \
|
||||
for k, v in obj.items()})
|
||||
return f'{{{dict_str}}}'
|
||||
elif isinstance(obj, list):
|
||||
return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
|
||||
elif isinstance(obj, set):
|
||||
return f"[{', '.join([prepare_object_to_dump(v) for v in list(obj)])}]"
|
||||
# return [prepare_object_to_dump(v) for v in list(obj)]
|
||||
elif isinstance(obj, tuple):
|
||||
return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
|
||||
elif isinstance(obj, enum.Enum):
|
||||
return repr(obj)
|
||||
elif isinstance(obj, torch.Tensor):
|
||||
# We only print the 'draft' of the tensor to not expose sensitive data
|
||||
# and to get some metadata in case of CUDA runtime crashed
|
||||
return (f"Tensor(shape={obj.shape}, "
|
||||
f"device={obj.device},"
|
||||
f"dtype={obj.dtype})")
|
||||
elif hasattr(obj, 'anon_repr'):
|
||||
return obj.anon_repr()
|
||||
elif hasattr(obj, '__dict__'):
|
||||
items = obj.__dict__.items()
|
||||
dict_str = ','.join([f'{str(k)}={prepare_object_to_dump(v)}' \
|
||||
for k, v in items])
|
||||
return (f"{type(obj).__name__}({dict_str})")
|
||||
else:
|
||||
# Hacky way to make sure we can serialize the object in JSON format
|
||||
try:
|
||||
return json.dumps(obj)
|
||||
except (TypeError, OverflowError):
|
||||
return repr(obj)
|
||||
|
||||
|
||||
def dump_engine_exception(config: VllmConfig,
|
||||
scheduler_output: SchedulerOutput,
|
||||
scheduler_stats: Optional[SchedulerStats]):
|
||||
# NOTE: ensure we can log extra info without risking raises
|
||||
# unexpected errors during logging
|
||||
with contextlib.suppress(BaseException):
|
||||
_dump_engine_exception(config, scheduler_output, scheduler_stats)
|
||||
|
||||
|
||||
def _dump_engine_exception(config: VllmConfig,
|
||||
scheduler_output: SchedulerOutput,
|
||||
scheduler_stats: Optional[SchedulerStats]):
|
||||
logger.error("Dumping input data")
|
||||
|
||||
logger.error(
|
||||
"V1 LLM engine (v%s) with config: %s, ",
|
||||
VLLM_VERSION,
|
||||
config,
|
||||
)
|
||||
|
||||
try:
|
||||
dump_obj = prepare_object_to_dump(scheduler_output)
|
||||
logger.error("Dumping scheduler output for model execution:")
|
||||
logger.error(dump_obj)
|
||||
if scheduler_stats:
|
||||
logger.error(scheduler_stats)
|
||||
except BaseException as exception:
|
||||
logger.error("Error preparing object to dump")
|
||||
logger.error(repr(exception))
|
||||
@ -48,6 +48,33 @@ class NewRequestData:
|
||||
lora_request=request.lora_request,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"NewRequestData("
|
||||
f"req_id={self.req_id},"
|
||||
f"prompt_token_ids={self.prompt_token_ids},"
|
||||
f"mm_inputs={self.mm_inputs},"
|
||||
f"mm_hashes={self.mm_hashes},"
|
||||
f"mm_positions={self.mm_positions},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
f"block_ids={self.block_ids},"
|
||||
f"num_computed_tokens={self.num_computed_tokens},"
|
||||
f"lora_request={self.lora_request}"
|
||||
")")
|
||||
|
||||
# Version of __repr__ with the prompt data obfuscated
|
||||
def anon_repr(self):
|
||||
return (f"NewRequestData("
|
||||
f"req_id={self.req_id},"
|
||||
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
|
||||
f"mm_inputs={self.mm_inputs},"
|
||||
f"mm_hashes={self.mm_hashes},"
|
||||
f"mm_positions={self.mm_positions},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
f"block_ids={self.block_ids},"
|
||||
f"num_computed_tokens={self.num_computed_tokens},"
|
||||
f"lora_request={self.lora_request}"
|
||||
")")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestData:
|
||||
|
||||
@ -19,6 +19,7 @@ from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.executor.multiproc_worker_utils import _add_prefix
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils.dump_input import dump_engine_exception
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
@ -56,6 +57,7 @@ class EngineCore:
|
||||
executor_fail_callback: Optional[Callable] = None):
|
||||
assert vllm_config.model_config.runner_type != "pooling"
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
|
||||
VLLM_VERSION, vllm_config)
|
||||
|
||||
@ -191,6 +193,16 @@ class EngineCore:
|
||||
self.scheduler.finish_requests(request_ids,
|
||||
RequestStatus.FINISHED_ABORTED)
|
||||
|
||||
def execute_model(self, scheduler_output: SchedulerOutput):
|
||||
try:
|
||||
return self.model_executor.execute_model(scheduler_output)
|
||||
except BaseException as err:
|
||||
# NOTE: This method is exception-free
|
||||
dump_engine_exception(self.vllm_config, scheduler_output,
|
||||
self.scheduler.make_stats())
|
||||
# Re-raise exception
|
||||
raise err
|
||||
|
||||
def step(self) -> EngineCoreOutputs:
|
||||
"""Schedule, execute, and make output."""
|
||||
|
||||
@ -202,9 +214,9 @@ class EngineCore:
|
||||
scheduler_stats=self.scheduler.make_stats(),
|
||||
)
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
output = self.model_executor.execute_model(scheduler_output)
|
||||
model_output = self.execute_model(scheduler_output)
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, output) # type: ignore
|
||||
scheduler_output, model_output) # type: ignore
|
||||
|
||||
return engine_core_outputs
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user