[Bugfix] Do not crash V0 engine on input errors (#13101)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
parent: ec8a5e5386
commit: 3f808cc044
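
A quick sketch of what this change does: if preparing the inputs for one request fails (for example, a multimodal processor rejecting bad image data), the error is tagged with that request's ID, only that request is aborted, and the rest of the batch keeps running instead of the whole V0 engine crashing. The snippet below is a minimal, self-contained illustration of that isolation pattern; the names (prepare_inputs, engine_step, the "evil-*" request IDs) are invented for the sketch and are not vLLM's actual API.

# Minimal sketch (toy names, not vLLM code): abort only the request whose
# input preparation failed, and keep serving the rest of the batch.
def prepare_inputs(request_id, prompt):
    if prompt is None:  # simulate bad input, e.g. invalid image data
        raise ValueError(f"cannot process inputs for {request_id}")
    return prompt.upper()


def engine_step(batch):
    outputs, aborted = {}, []
    for request_id, prompt in batch.items():
        try:
            outputs[request_id] = prepare_inputs(request_id, prompt)
        except ValueError:
            # Before this fix an input error here took the whole engine down;
            # now only the offending request is aborted.
            aborted.append(request_id)
    return outputs, aborted


outputs, aborted = engine_step({"good-1": "hi", "evil-1": None, "good-2": "yo"})
assert aborted == ["evil-1"] and sorted(outputs) == ["good-1", "good-2"]
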
@@ -18,6 +18,7 @@ from vllm.engine.multiprocessing.engine import MQLLMEngine
 from vllm.entrypoints.openai.api_server import build_async_engine_client
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.lora.request import LoRARequest
+from vllm.sequence import SequenceGroupMetadata
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
 
@@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket):
             await client.check_health()
 
         client.close()
+
+
+def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
+                                   ipc_path: str):
+    """Simulate an exception while preparing inputs for the model.
+    In the wild, this could be something like a multimodal input processor
+    failing on invalid image data."""
+
+    # Make engine.
+    engine = MQLLMEngine.from_engine_args(
+        engine_args=engine_args,
+        usage_context=UsageContext.UNKNOWN_CONTEXT,
+        ipc_path=ipc_path)
+
+    runner = engine.engine.model_executor.driver_worker.worker.model_runner
+
+    # Raise error in the model runner when adding a sequence group.
+    # See class ModelInputForGPUBuilder
+    def raiser(_, seq_group_metadata: SequenceGroupMetadata):
+        if seq_group_metadata.request_id.startswith("evil"):
+            raise RAISED_ERROR(RAISED_VALUE)
+
+    runner.builder.per_seq_group_compute_fns.append(raiser)
+
+    # Run engine.
+    engine.start()
+
+
+@pytest.mark.asyncio
+async def test_failed_inputs(tmp_socket):
+    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket,
+                           run_fn=run_with_evil_input_processing) as engine:
+
+        client = await engine.make_client()
+        assert client.is_running
+
+        # Engine should be healthy
+        await client.check_health()
+
+        async def run_failing_request():
+            async for _ in client.generate(
+                    prompt="Hello my name is",
+                    sampling_params=SamplingParams(max_tokens=10),
+                    request_id="evil" + str(uuid.uuid4())):
+                pass
+
+        async def run_passing_request():
+            async for _ in client.generate(
+                    prompt="Hello my name is",
+                    sampling_params=SamplingParams(max_tokens=10),
+                    request_id=str(uuid.uuid4())):
+                pass
+
+        passing_tasks = [
+            asyncio.create_task(run_passing_request()) for _ in range(10)
+        ]
+        failing_tasks = [
+            asyncio.create_task(run_failing_request()) for _ in range(10)
+        ]
+        await asyncio.gather(*failing_tasks, return_exceptions=True)
+        await asyncio.gather(*passing_tasks)
+
+        # All the bad inputs should have raised
+        for task in failing_tasks:
+            with pytest.raises(RAISED_ERROR):
+                task.result()
+
+        # But all good inputs should have still succeeded
+        for task in passing_tasks:
+            task.result()
+
+        # And the engine should remain healthy
+        assert not client.errored
+        await client.check_health()
+
+        client.close()
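
The test above injects the failure by appending a hook to the model runner's per_seq_group_compute_fns so that only requests whose ID starts with "evil" raise, then drives failing and passing requests concurrently. The concurrency idiom it leans on, gathering the expected-to-fail tasks with return_exceptions=True so one failure does not cancel the rest and then inspecting each task individually, is plain asyncio; here is a standalone sketch of just that idiom, independent of vLLM:

import asyncio


async def work(i):
    # Odd inputs stand in for the "evil" requests that fail during input
    # processing; even inputs stand in for the healthy requests.
    if i % 2:
        raise ValueError(f"bad input {i}")
    return i


async def main():
    tasks = [asyncio.create_task(work(i)) for i in range(6)]
    # return_exceptions=True keeps one failure from propagating out of gather,
    # so every task still runs to completion and can be inspected afterwards.
    await asyncio.gather(*tasks, return_exceptions=True)
    for i, task in enumerate(tasks):
        if i % 2:
            assert isinstance(task.exception(), ValueError)
        else:
            assert task.result() == i


asyncio.run(main())

The hunks that follow are the engine-side half of the change, in the LLMEngine step loop.
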
@@ -60,6 +60,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
 from vllm.utils import (Counter, Device, deprecate_kwargs,
                         resolve_obj_by_qualname, weak_bind)
 from vllm.version import __version__ as VLLM_VERSION
+from vllm.worker.model_runner_base import InputProcessingError
 
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
@@ -410,6 +411,10 @@ class LLMEngine:
 
         self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
 
+        # Flag to set when an input fails to process and the engine should run
+        # the next step without re-scheduling.
+        self._skip_scheduling_next_step = False
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 
@@ -1334,7 +1339,11 @@ class LLMEngine:
         # Skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
-        if not self._has_remaining_steps(seq_group_metadata_list):
+        # The scheduler is also skipped if a single request caused the last
+        # engine step to fail, and the previous schedule needs to be rerun.
+        if not self._has_remaining_steps(
+                seq_group_metadata_list
+        ) and not self._skip_scheduling_next_step:
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
@@ -1388,8 +1397,23 @@ class LLMEngine:
                 execute_model_req.async_callback = self.async_callbacks[
                     virtual_engine]
 
-            outputs = self.model_executor.execute_model(
-                execute_model_req=execute_model_req)
+            try:
+                outputs = self.model_executor.execute_model(
+                    execute_model_req=execute_model_req)
+                self._skip_scheduling_next_step = False
+            except InputProcessingError as e:
+                # The input for this request cannot be processed, so we must
+                # abort it. If there are remaining requests in the batch that
+                # have been scheduled, they will be retried on the next step.
+                invalid_request_id = e.request_id
+                self._abort_and_cache_schedule(
+                    request_id=invalid_request_id,
+                    virtual_engine=virtual_engine,
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    scheduler_outputs=scheduler_outputs,
+                    allow_async_output_proc=allow_async_output_proc)
+                # Raise so the caller is notified that this request failed
+                raise
 
             # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
@@ -1464,6 +1488,38 @@ class LLMEngine:
 
         return ctx.request_outputs
 
+    def _abort_and_cache_schedule(
+            self, request_id: str, virtual_engine: int,
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            scheduler_outputs: SchedulerOutputs,
+            allow_async_output_proc: bool) -> None:
+        """Aborts a single request, and caches the scheduler outputs minus that
+        request. This allows the next step to continue processing the remaining
+        requests without having to re-run the scheduler."""
+
+        # Abort the request and remove its sequence group from the current
+        # schedule
+        self.abort_request(request_id)
+        for i, metadata in enumerate(seq_group_metadata_list):
+            if metadata.request_id == request_id:
+                del seq_group_metadata_list[i]
+                break
+        for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
+            if group.seq_group.request_id == request_id:
+                del scheduler_outputs.scheduled_seq_groups[i]
+                break
+
+        # If there are still other sequence groups left in the schedule, cache
+        # them and flag the engine to reuse the schedule.
+        if len(seq_group_metadata_list) > 0:
+            self._skip_scheduling_next_step = True
+            # Reuse multi-step caching logic
+            self._cache_scheduler_outputs_for_multi_step(
+                virtual_engine=virtual_engine,
+                scheduler_outputs=scheduler_outputs,
+                seq_group_metadata_list=seq_group_metadata_list,
+                allow_async_output_proc=allow_async_output_proc)
+
     def _has_remaining_steps(
         self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
     ) -> bool:
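
The new _abort_and_cache_schedule helper above removes the failed request from both the sequence-group metadata list and the scheduler outputs, then caches whatever is left and sets _skip_scheduling_next_step so the next step() replays the cached batch instead of re-running the scheduler. A toy sketch of that replay pattern follows; the class and method names are invented for illustration and do not mirror vLLM's internals.

# Toy illustration of "abort one request, replay the survivors next step".
class ToyScheduler:
    def __init__(self, pending):
        self.pending = list(pending)

    def schedule(self):
        batch, self.pending = self.pending, []
        return batch


class ToyEngineLoop:
    def __init__(self, scheduler):
        self.scheduler = scheduler
        self._cached_batch = None
        self._skip_scheduling_next_step = False

    def step(self, failing_id=None):
        if self._skip_scheduling_next_step and self._cached_batch:
            batch = self._cached_batch  # replay the previous schedule
            self._skip_scheduling_next_step = False
        else:
            batch = self.scheduler.schedule()
        if failing_id in batch:
            # Drop only the bad request; keep the rest for the next step.
            batch.remove(failing_id)
            self._cached_batch = batch
            self._skip_scheduling_next_step = bool(batch)
            raise RuntimeError(f"aborted {failing_id}")
        return batch


loop = ToyEngineLoop(ToyScheduler(["a", "evil", "b"]))
try:
    loop.step(failing_id="evil")
except RuntimeError:
    pass
assert loop.step() == ["a", "b"]  # survivors run without re-scheduling

The next hunks apply the matching handling in the multiprocessing engine worker (MQLLMEngine).
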
@@ -27,6 +27,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.usage.usage_lib import UsageContext
+from vllm.worker.model_runner_base import InputProcessingError
 
 logger = init_logger(__name__)
 
@@ -210,6 +211,14 @@ class MQLLMEngine:
             return self.engine.step()
         except SystemExit:
             raise
+        except InputProcessingError as e:
+            # Special case where we handle an error preparing the inputs for
+            # a single request in the batch
+            rpc_err = RPCError(request_id=e.request_id,
+                               is_engine_errored=False,
+                               exception=e.__cause__)
+            self._send_outputs(rpc_err)
+            return []
         except BaseException as e:
             self._set_errored(e)
             rpc_err = RPCError(request_id=None,
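
In the worker loop above, the ordering of the except clauses is the point: SystemExit is re-raised, the new request-scoped InputProcessingError is answered with a non-fatal RPCError that forwards the original exception via __cause__, and only anything else marks the whole engine as errored. A standalone sketch of that dispatch, with toy names in place of the RPC machinery:

class InputProcessingError(Exception):
    # Same idea as the class added in this commit: carry the bad request's ID.
    def __init__(self, request_id, message):
        self.request_id = request_id
        super().__init__(message)


def engine_step(run_step):
    """Run one step; returns (reply, engine_errored)."""
    try:
        return {"result": run_step()}, False
    except SystemExit:
        raise
    except InputProcessingError as e:
        # Only the offending request sees an error; the engine stays healthy.
        return {"request_id": e.request_id, "error": repr(e.__cause__)}, False
    except BaseException as e:
        # Anything else is treated as fatal for the whole engine.
        return {"request_id": None, "error": repr(e)}, True


def bad_input():
    try:
        raise ValueError("invalid image data")  # stand-in for the real failure
    except ValueError as err:
        raise InputProcessingError("req-1", str(err)) from err


reply, errored = engine_step(bad_input)
assert not errored and reply["request_id"] == "req-1"
reply, errored = engine_step(lambda: 1 / 0)
assert errored and reply["request_id"] is None

The remaining hunks wire this up in the GPU model runner, which wraps any failure from add_seq_group, and in model_runner_base, which defines InputProcessingError itself.
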
@@ -53,8 +53,8 @@ from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
                         is_pin_memory_available, supports_dynamo,
                         weak_ref_tensor)
 from vllm.worker.model_runner_base import (
-    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
-    _add_attn_metadata_broadcastable_dict,
+    InputProcessingError, ModelRunnerBase, ModelRunnerInputBase,
+    ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict,
     _init_attn_metadata_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
@@ -1216,7 +1216,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         """
         self.builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
-            self.builder.add_seq_group(seq_group_metadata)
+            try:
+                self.builder.add_seq_group(seq_group_metadata)
+            except Exception as e:
+                # Raise an exception that tracks the ID of the bad request
+                raise InputProcessingError(seq_group_metadata.request_id,
+                                           str(e)) from e
 
         self.builder.reset_cached_inter_data()
 
@@ -261,3 +261,21 @@ class ModelRunnerWrapperBase:
 
     def __getattr__(self, attr):
         return getattr(self.model_runner, attr)
+
+
+class InputProcessingError(Exception):
+    """This exception is raised when an error occurs preparing the inputs for
+    a single sequence group.
+    This allows the engine to gracefully handle errors with a single sequence
+    group without having to fail the entire batch.
+    """
+
+    def __init__(self, request_id, message):
+        """request_id is the id of the offending sequence group"""
+        self.request_id = request_id
+        self.message = message
+        super().__init__(self.message)
+
+    def __str__(self):
+        return "Failed to prepare inputs for sequence group with request id: " \
+            f"{self.request_id}, Error: {self.message}"