diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index faf34d95735f4..f644504a5b937 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -319,7 +319,10 @@ steps:
     # TODO: accuracy does not match, whether setting
     # VLLM_USE_FLASHINFER_SAMPLER or not on H100.
     - pytest -v -s v1/e2e
-    - pytest -v -s v1/engine
+    # Run this test standalone for now;
+    # need to untangle the (implicit) use of spawn/fork across the tests.
+    - pytest -v -s v1/engine/test_preprocess_error_handling.py
+    - pytest -v -s v1/engine --ignore v1/engine/test_preprocess_error_handling.py
 
 - label: V1 Test entrypoints # 35min
   timeout_in_minutes: 50
diff --git a/tests/v1/engine/test_preprocess_error_handling.py b/tests/v1/engine/test_preprocess_error_handling.py
new file mode 100644
index 0000000000000..0586cc64fa104
--- /dev/null
+++ b/tests/v1/engine/test_preprocess_error_handling.py
@@ -0,0 +1,56 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch.cuda
+
+from vllm import LLM, SamplingParams
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.core import EngineCore
+
+MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
+
+
+def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch):
+    """Test that preprocessing errors are handled gracefully."""
+
+    assert not torch.cuda.is_initialized(), (
+        "fork needs to be used for the engine "
+        "core process and this isn't possible if cuda is already initialized"
+    )
+
+    # Store the original method to call for non-failing requests.
+    original_preprocess = EngineCore.preprocess_add_request
+
+    # Monkeypatch to make preprocess_add_request raise an exception
+    # only for requests whose first token id is 333.
+    def conditional_failing_preprocess(self, request: EngineCoreRequest):
+        # Fail if the first token id is 333
+        if request.prompt_token_ids and request.prompt_token_ids[0] == 333:
+            raise ValueError("Simulated preprocessing error!")
+        return original_preprocess(self, request)
+
+    monkeypatch.setattr(
+        EngineCore, "preprocess_add_request", conditional_failing_preprocess
+    )
+
+    llm = LLM(model=MODEL_NAME)
+
+    # Create a failing request by passing token ids directly, since
+    # LLM.generate would otherwise tokenize the prompt for us.
+    from vllm.inputs import TokensPrompt
+
+    # The request should finish with an error, triggered by the special
+    # token id (333) hitting the simulated preprocessing failure.
+    failing_prompt = TokensPrompt(prompt_token_ids=[333])
+    outputs = llm.generate(failing_prompt, SamplingParams(max_tokens=10))  # type: ignore
+    assert len(outputs) == 1
+    assert len(outputs[0].outputs[0].token_ids) == 0
+    assert outputs[0].finished
+    assert outputs[0].outputs[0].finish_reason == "error"
+
+    # Verify the engine is still functional with a normal request.
+    outputs = llm.generate("Hello, my name is", SamplingParams(max_tokens=10))
+    assert len(outputs) == 1
+    assert len(outputs[0].outputs[0].token_ids) > 0
+    assert outputs[0].outputs[0].finish_reason in ("stop", "length")
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 9e2571201a684..40c3e9a515e18 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -43,9 +43,11 @@ from vllm.v1.core.kv_cache_utils import (
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.engine import (
+    EngineCoreOutput,
     EngineCoreOutputs,
     EngineCoreRequest,
     EngineCoreRequestType,
+    FinishReason,
     ReconfigureDistributedRequest,
     ReconfigureRankType,
     UtilityOutput,
@@ -1055,9 +1057,14 @@ class EngineCoreProc(EngineCore):
                     request_type = EngineCoreRequestType(bytes(type_frame.buffer))
 
                     # Deserialize the request data.
+                    request: Any
                     if request_type == EngineCoreRequestType.ADD:
-                        request = add_request_decoder.decode(data_frames)
-                        request = self.preprocess_add_request(request)
+                        req: EngineCoreRequest = add_request_decoder.decode(data_frames)
+                        try:
+                            request = self.preprocess_add_request(req)
+                        except Exception:
+                            self._handle_request_preproc_error(req)
+                            continue
                     else:
                         request = generic_decoder.decode(data_frames)
 
@@ -1141,6 +1148,30 @@ class EngineCoreProc(EngineCore):
                     # Limit the number of buffers to reuse.
                     reuse_buffers.append(buffer)
 
+    def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
+        """Log and return a request-scoped error response for exceptions raised
+        from add-request preprocessing in the input socket processing thread.
+        """
+        logger.exception(
+            "Unexpected error pre-processing request %s", request.request_id
+        )
+        self.output_queue.put_nowait(
+            (
+                request.client_index,
+                EngineCoreOutputs(
+                    engine_index=self.engine_index,
+                    finished_requests={request.request_id},
+                    outputs=[
+                        EngineCoreOutput(
+                            request_id=request.request_id,
+                            new_token_ids=[],
+                            finish_reason=FinishReason.ERROR,
+                        )
+                    ],
+                ),
+            )
+        )
+
 
 class DPEngineCoreProc(EngineCoreProc):
     """ZMQ-wrapper for running EngineCore in background process