From a9e4106f28834315de4bfb1cb1186c9a2dc95856 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Wed, 10 Dec 2025 14:00:52 -0500 Subject: [PATCH] [P/D] KV Load Failure Recovery/Abort Configuration (#26813) Signed-off-by: Will Eaton Signed-off-by: Will Eaton Signed-off-by: Nick Hill Co-authored-by: Mark McLoughlin Co-authored-by: Nick Hill Co-authored-by: chaunceyjiang --- tests/entrypoints/openai/test_chat_error.py | 228 +++++++++ .../openai/test_completion_error.py | 216 +++++++++ .../openai/test_responses_error.py | 89 ++++ .../unit/test_cache_pollution_prevention.py | 163 +++++++ .../unit/test_error_propagation.py | 147 ++++++ .../unit/test_invalid_blocks_correctness.py | 454 ++++++++++++++++++ vllm/config/kv_transfer.py | 5 + vllm/entrypoints/openai/serving_chat.py | 17 +- vllm/entrypoints/openai/serving_completion.py | 15 +- vllm/entrypoints/openai/serving_engine.py | 61 +++ vllm/entrypoints/openai/serving_responses.py | 53 +- vllm/v1/core/block_pool.py | 19 + vllm/v1/core/kv_cache_manager.py | 8 + vllm/v1/core/sched/scheduler.py | 114 +++-- vllm/v1/engine/__init__.py | 9 +- vllm/v1/request.py | 2 + 16 files changed, 1552 insertions(+), 48 deletions(-) create mode 100644 tests/entrypoints/openai/test_chat_error.py create mode 100644 tests/entrypoints/openai/test_completion_error.py create mode 100644 tests/entrypoints/openai/test_responses_error.py create mode 100644 tests/v1/kv_connector/unit/test_cache_pollution_prevention.py create mode 100644 tests/v1/kv_connector/unit/test_error_propagation.py create mode 100644 tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py diff --git a/tests/entrypoints/openai/test_chat_error.py b/tests/entrypoints/openai/test_chat_error.py new file mode 100644 index 0000000000000..102eeaf614410 --- /dev/null +++ b/tests/entrypoints/openai/test_chat_error.py @@ -0,0 +1,228 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass, field +from http import HTTPStatus +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from vllm.config.multimodal import MultiModalConfig +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine.async_llm import AsyncLLM + +MODEL_NAME = "openai-community/gpt2" +MODEL_NAME_SHORT = "gpt2" +BASE_MODEL_PATHS = [ + BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME), + BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT), +] + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + task = "generate" + runner_type = "generate" + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + multimodal_config = MultiModalConfig() + hf_config = MockHFConfig() + logits_processor_pattern = None + logits_processors: list[str] | None = None + diff_sampling_param: dict | None = None + allowed_local_media_path: str = "" + allowed_media_domains: list[str] | None = None + encoder_config = None + generation_config: str = "auto" + media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + skip_tokenizer_init = False + + def get_diff_sampling_param(self): + return self.diff_sampling_param or {} + + +def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: + models = OpenAIServingModels( + engine_client=engine, + base_model_paths=BASE_MODEL_PATHS, + ) + serving_chat = OpenAIServingChat( + engine, + models, + response_role="assistant", + request_logger=None, + chat_template=None, + chat_template_content_format="auto", + ) + + async def _fake_process_inputs( + request_id, + engine_prompt, + sampling_params, + *, + lora_request, + trace_headers, + priority, + ): + return dict(engine_prompt), {} + + async def _fake_preprocess_chat(*args, **kwargs): + # return conversation, request_prompts, engine_prompts + return ( + [{"role": "user", "content": "Test"}], + [[1, 2, 3]], + [{"prompt_token_ids": [1, 2, 3]}], + ) + + serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs) + serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat) + return serving_chat + + +@pytest.mark.asyncio +async def test_chat_error_non_stream(): + """test finish_reason='error' returns 500 InternalServerError (non-streaming)""" + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.io_processor = MagicMock() + + serving_chat = _build_serving_chat(mock_engine) + + completion_output = CompletionOutput( + index=0, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + + request_output = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output], + finished=True, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + async def mock_generate(*args, **kwargs): + yield request_output + + mock_engine.generate = MagicMock(side_effect=mock_generate) + + request = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{"role": "user", "content": "Test prompt"}], + max_tokens=10, + stream=False, + ) + + response = await serving_chat.create_chat_completion(request) + + assert isinstance(response, ErrorResponse) + assert response.error.type == "InternalServerError" + assert response.error.message == "Internal server error" + assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_chat_error_stream(): + """test finish_reason='error' returns 500 InternalServerError (streaming)""" + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.io_processor = MagicMock() + + serving_chat = _build_serving_chat(mock_engine) + + completion_output_1 = CompletionOutput( + index=0, + text="Hello", + token_ids=[100], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ) + + request_output_1 = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output_1], + finished=False, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + completion_output_2 = CompletionOutput( + index=0, + text="Hello", + token_ids=[100], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + + request_output_2 = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output_2], + finished=True, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + async def mock_generate(*args, **kwargs): + yield request_output_1 + yield request_output_2 + + mock_engine.generate = MagicMock(side_effect=mock_generate) + + request = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{"role": "user", "content": "Test prompt"}], + max_tokens=10, + stream=True, + ) + + response = await serving_chat.create_chat_completion(request) + + chunks = [] + async for chunk in response: + chunks.append(chunk) + + assert len(chunks) >= 2 + assert any("Internal server error" in chunk for chunk in chunks), ( + f"Expected error message in chunks: {chunks}" + ) + assert chunks[-1] == "data: [DONE]\n\n" diff --git a/tests/entrypoints/openai/test_completion_error.py b/tests/entrypoints/openai/test_completion_error.py new file mode 100644 index 0000000000000..ca56cc2ddb6a7 --- /dev/null +++ b/tests/entrypoints/openai/test_completion_error.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass, field +from http import HTTPStatus +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from vllm.config.multimodal import MultiModalConfig +from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine.async_llm import AsyncLLM + +MODEL_NAME = "openai-community/gpt2" +MODEL_NAME_SHORT = "gpt2" +BASE_MODEL_PATHS = [ + BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME), + BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT), +] + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + task = "generate" + runner_type = "generate" + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + multimodal_config = MultiModalConfig() + hf_config = MockHFConfig() + logits_processor_pattern = None + logits_processors: list[str] | None = None + diff_sampling_param: dict | None = None + allowed_local_media_path: str = "" + allowed_media_domains: list[str] | None = None + encoder_config = None + generation_config: str = "auto" + media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + skip_tokenizer_init = False + + def get_diff_sampling_param(self): + return self.diff_sampling_param or {} + + +def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion: + models = OpenAIServingModels( + engine_client=engine, + base_model_paths=BASE_MODEL_PATHS, + ) + serving_completion = OpenAIServingCompletion( + engine, + models, + request_logger=None, + ) + + async def _fake_process_inputs( + request_id, + engine_prompt, + sampling_params, + *, + lora_request, + trace_headers, + priority, + ): + return dict(engine_prompt), {} + + serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs) + return serving_completion + + +@pytest.mark.asyncio +async def test_completion_error_non_stream(): + """test finish_reason='error' returns 500 InternalServerError (non-streaming)""" + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.io_processor = MagicMock() + + serving_completion = _build_serving_completion(mock_engine) + + completion_output = CompletionOutput( + index=0, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + + request_output = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output], + finished=True, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + async def mock_generate(*args, **kwargs): + yield request_output + + mock_engine.generate = MagicMock(side_effect=mock_generate) + + request = CompletionRequest( + model=MODEL_NAME, + prompt="Test prompt", + max_tokens=10, + stream=False, + ) + + response = await serving_completion.create_completion(request) + + assert isinstance(response, ErrorResponse) + assert response.error.type == "InternalServerError" + assert response.error.message == "Internal server error" + assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_completion_error_stream(): + """test finish_reason='error' returns 500 InternalServerError (streaming)""" + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.io_processor = MagicMock() + + serving_completion = _build_serving_completion(mock_engine) + + completion_output_1 = CompletionOutput( + index=0, + text="Hello", + token_ids=[100], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ) + + request_output_1 = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output_1], + finished=False, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + completion_output_2 = CompletionOutput( + index=0, + text="Hello", + token_ids=[100], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + + request_output_2 = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output_2], + finished=True, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + async def mock_generate(*args, **kwargs): + yield request_output_1 + yield request_output_2 + + mock_engine.generate = MagicMock(side_effect=mock_generate) + + request = CompletionRequest( + model=MODEL_NAME, + prompt="Test prompt", + max_tokens=10, + stream=True, + ) + + response = await serving_completion.create_completion(request) + + chunks = [] + async for chunk in response: + chunks.append(chunk) + + assert len(chunks) >= 2 + assert any("Internal server error" in chunk for chunk in chunks), ( + f"Expected error message in chunks: {chunks}" + ) + assert chunks[-1] == "data: [DONE]\n\n" diff --git a/tests/entrypoints/openai/test_responses_error.py b/tests/entrypoints/openai/test_responses_error.py new file mode 100644 index 0000000000000..f8ea178288835 --- /dev/null +++ b/tests/entrypoints/openai/test_responses_error.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from http import HTTPStatus +from unittest.mock import MagicMock + +import pytest + +from vllm.entrypoints.openai.protocol import ErrorResponse +from vllm.entrypoints.openai.serving_engine import GenerationError, OpenAIServing + + +@pytest.mark.asyncio +async def test_raise_if_error_raises_generation_error(): + """test _raise_if_error raises GenerationError""" + # create a minimal OpenAIServing instance + mock_engine = MagicMock() + mock_engine.model_config = MagicMock() + mock_engine.model_config.max_model_len = 100 + mock_models = MagicMock() + + serving = OpenAIServing( + engine_client=mock_engine, + models=mock_models, + request_logger=None, + ) + + # test that error finish_reason raises GenerationError + with pytest.raises(GenerationError) as exc_info: + serving._raise_if_error("error", "test-request-id") + + assert str(exc_info.value) == "Internal server error" + assert exc_info.value.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + # test that other finish_reasons don't raise + serving._raise_if_error("stop", "test-request-id") # should not raise + serving._raise_if_error("length", "test-request-id") # should not raise + serving._raise_if_error(None, "test-request-id") # should not raise + + +@pytest.mark.asyncio +async def test_convert_generation_error_to_response(): + """test _convert_generation_error_to_response creates proper ErrorResponse""" + mock_engine = MagicMock() + mock_engine.model_config = MagicMock() + mock_engine.model_config.max_model_len = 100 + mock_models = MagicMock() + + serving = OpenAIServing( + engine_client=mock_engine, + models=mock_models, + request_logger=None, + ) + + # create a GenerationError + gen_error = GenerationError("Internal server error") + + # convert to ErrorResponse + error_response = serving._convert_generation_error_to_response(gen_error) + + assert isinstance(error_response, ErrorResponse) + assert error_response.error.type == "InternalServerError" + assert error_response.error.message == "Internal server error" + assert error_response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_convert_generation_error_to_streaming_response(): + """test _convert_generation_error_to_streaming_response output""" + mock_engine = MagicMock() + mock_engine.model_config = MagicMock() + mock_engine.model_config.max_model_len = 100 + mock_models = MagicMock() + + serving = OpenAIServing( + engine_client=mock_engine, + models=mock_models, + request_logger=None, + ) + + # create a GenerationError + gen_error = GenerationError("Internal server error") + + # convert to streaming error response + error_json = serving._convert_generation_error_to_streaming_response(gen_error) + + assert isinstance(error_json, str) + assert "Internal server error" in error_json + assert "InternalServerError" in error_json diff --git a/tests/v1/kv_connector/unit/test_cache_pollution_prevention.py b/tests/v1/kv_connector/unit/test_cache_pollution_prevention.py new file mode 100644 index 0000000000000..ec3fb8231e19e --- /dev/null +++ b/tests/v1/kv_connector/unit/test_cache_pollution_prevention.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +test that invalid blocks are evicted from prefix cache to prevent pollution. + +verifies that when sync-loading fails, invalid blocks are removed from the +prefix cache hash table so future requests cannot match and reuse corrupted data. +""" + +from collections.abc import Callable +from unittest.mock import Mock + +import pytest + +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.request import Request, RequestStatus + +from .utils import ( + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) + +pytestmark = pytest.mark.cpu_test + + +def _make_get_num_new_matched_tokens( + req_num_new_matched_tokens: dict[str, int], + async_load: bool, +) -> Callable[[Request, int], tuple[int, bool]]: + def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]: + value = req_num_new_matched_tokens.get(request.request_id, 0) + return value, async_load + + return get_num_new_matched_tokens + + +@pytest.fixture +def fail_scheduler(): + """scheduler with kv_load_failure_policy='fail'""" + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_load_failure_policy = "fail" + return create_scheduler(vllm_config) + + +def test_invalid_blocks_evicted_prevents_cache_pollution( + fail_scheduler: Scheduler, +): + """ + verify invalid blocks are evicted to prevent future cache hits. + + scenario: + 1. request 1 loads externally-computed blocks (sync mode) + 2. some blocks fail to load and are marked invalid + 3. with fail policy, invalid blocks should be evicted from prefix cache + 4. request is marked as FINISHED_ERROR + """ + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * fail_scheduler.block_size + ) + + # request 1: will have invalid blocks + request1 = create_request(num_tokens=num_prompt_tokens, request_id=1) + fail_scheduler.add_request(request=request1) + + req_num_new_matched_tokens = { + request1.request_id: num_external_computed_tokens, + } + + # mock connector indicating sync load + fail_scheduler.connector = Mock() + fail_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False) + ) + fail_scheduler.connector.request_finished.return_value = (False, None) + fail_scheduler.connector.take_events.return_value = () + + scheduler_output = fail_scheduler.schedule() + + # request should be running with sync KV load + assert len(fail_scheduler.running) == 1 + assert request1.status == RequestStatus.RUNNING + + # get allocated block IDs + req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0] + invalid_block_id = req_block_ids[invalid_block_idx] + invalid_block_ids = {invalid_block_id} + + # get the block object to verify eviction later + block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id] + + # cache the blocks to simulate they've been computed and cached + # (in real scenario blocks would be cached after compute) + fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens) + + # verify block has a hash (is cached) before reporting invalid blocks + assert block.block_hash is not None, ( + f"block {invalid_block_id} should be cached (have a hash) before " + f"eviction test, but hash is None" + ) + + # report invalid blocks + model_runner_output = create_model_runner_output( + [request1], + invalid_block_ids=invalid_block_ids, + use_eos=False, + ) + + fail_scheduler.update_from_output(scheduler_output, model_runner_output) + + # verify request finished with error (fail policy) + assert request1.status == RequestStatus.FINISHED_ERROR + + # critical assertion: invalid block and all subsequent blocks should be evicted + # all blocks from invalid_block_idx onwards become invalid since they were + # computed based on the failed block + for idx in range(invalid_block_idx, len(req_block_ids)): + block_id = req_block_ids[idx] + block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id] + assert block_obj.block_hash is None, ( + f"block {block_id} at index {idx} should have been evicted " + f"(hash reset to None), but hash is {block_obj.block_hash}. " + f"All blocks from index {invalid_block_idx} onwards should be evicted " + f"since they depend on the invalid block at index {invalid_block_idx}." + ) + + # verify cache contains exactly the valid blocks (before first affected block) + # and none of the invalid blocks (from first affected block onwards) + + # valid blocks: all blocks before invalid_block_idx should be cached + for idx in range(invalid_block_idx): + block_id = req_block_ids[idx] + block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id] + assert block_obj.block_hash is not None, ( + f"valid block {block_id} at index {idx} should still be cached " + f"(have a hash), but hash is None. Only blocks from index " + f"{invalid_block_idx} onwards should be evicted." + ) + + # invalid blocks: verify they're not in the cached_block_hash_to_block map + cached_blocks = ( + fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block + ) + cached_block_ids = { + b.block_id + for blocks_val in cached_blocks._cache.values() + for b in ( + [blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values() + ) + } + + for idx in range(invalid_block_idx, len(req_block_ids)): + block_id = req_block_ids[idx] + assert block_id not in cached_block_ids, ( + f"invalid block {block_id} at index {idx} should not be in cache hash table" + ) diff --git a/tests/v1/kv_connector/unit/test_error_propagation.py b/tests/v1/kv_connector/unit/test_error_propagation.py new file mode 100644 index 0000000000000..20e181f379f5c --- /dev/null +++ b/tests/v1/kv_connector/unit/test_error_propagation.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable +from unittest.mock import Mock + +import pytest + +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.request import FinishReason, Request, RequestStatus + +from .utils import ( + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) + +pytestmark = pytest.mark.cpu_test + + +def _make_get_num_new_matched_tokens( + req_num_new_matched_tokens: dict[str, int], + async_load: bool, +) -> Callable[[Request, int], tuple[int, bool]]: + def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]: + value = req_num_new_matched_tokens.get(request.request_id, 0) + return value, async_load + + return get_num_new_matched_tokens + + +@pytest.fixture +def fail_scheduler(): + """scheduler with kv_load_failure_policy='fail'""" + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_load_failure_policy = "fail" + return create_scheduler(vllm_config) + + +def test_error_propagation_sync_load(fail_scheduler: Scheduler): + """test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)""" + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * fail_scheduler.block_size + ) + + request = create_request(num_tokens=num_prompt_tokens) + fail_scheduler.add_request(request=request) + + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + fail_scheduler.connector = Mock() + fail_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False) + ) + fail_scheduler.connector.request_finished.return_value = (False, None) + fail_scheduler.connector.take_events.return_value = () + + scheduler_output = fail_scheduler.schedule() + + assert len(fail_scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1 + + req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0] + invalid_block_ids = {req_block_ids[invalid_block_idx]} + model_runner_output = create_model_runner_output( + [request], + invalid_block_ids=invalid_block_ids, + use_eos=True, + ) + + outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output) + + assert request.status == RequestStatus.FINISHED_ERROR + assert request.get_finished_reason() == FinishReason.ERROR + + assert len(outputs) == 1 + engine_outputs = next(iter(outputs.values())) + assert len(engine_outputs.outputs) == 1 + output = engine_outputs.outputs[0] + assert output.request_id == request.request_id + assert output.finish_reason == FinishReason.ERROR + + assert len(fail_scheduler.running) == 0 + + +def test_error_propagation_async_load(fail_scheduler: Scheduler): + """test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)""" + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * fail_scheduler.block_size + ) + + request = create_request(num_tokens=num_prompt_tokens) + fail_scheduler.add_request(request=request) + + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + fail_scheduler.connector = Mock() + fail_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, True) + ) + fail_scheduler.connector.request_finished.return_value = (False, None) + fail_scheduler.connector.take_events.return_value = () + + scheduler_output = fail_scheduler.schedule() + + assert len(fail_scheduler.waiting) == 1 + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert request.num_computed_tokens == 0 + + (req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id) + invalid_block_ids = {req_block_ids[invalid_block_idx]} + model_runner_output = create_model_runner_output( + reqs=[], + finished_recving=set(), + invalid_block_ids=invalid_block_ids, + use_eos=True, + ) + + outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output) + + assert request.status == RequestStatus.FINISHED_ERROR + assert request.get_finished_reason() == FinishReason.ERROR + + assert len(outputs) == 1 + engine_outputs = next(iter(outputs.values())) + assert len(engine_outputs.outputs) == 1 + output = engine_outputs.outputs[0] + assert output.request_id == request.request_id + assert output.finish_reason == FinishReason.ERROR + + assert len(fail_scheduler.waiting) == 0 diff --git a/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py new file mode 100644 index 0000000000000..940f3a98308b6 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py @@ -0,0 +1,454 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Tests for correctness in invalid block handling. + +These tests verify correct behavior in three scenarios: +1. Sync recompute case: Blocks should not be freed for running requests + that need to recompute invalid blocks +2. Sync fail case: Invalid blocks must be evicted from cache when request fails +3. Async recompute case: Invalid blocks should not be cached after transfer +""" + +from collections.abc import Callable +from unittest.mock import Mock + +import pytest + +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.request import FinishReason, Request, RequestStatus + +from .utils import ( + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) + +pytestmark = pytest.mark.cpu_test + + +def _make_get_num_new_matched_tokens( + req_num_new_matched_tokens: dict[str, int], + async_load: bool, +) -> Callable[[Request, int], tuple[int, bool]]: + def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]: + value = req_num_new_matched_tokens.get(request.request_id, 0) + return value, async_load + + return get_num_new_matched_tokens + + +@pytest.fixture +def fail_scheduler(): + """scheduler with kv_load_failure_policy='fail'""" + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_load_failure_policy = "fail" + return create_scheduler(vllm_config) + + +@pytest.fixture +def recompute_scheduler(): + """scheduler with kv_load_failure_policy='recompute'""" + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_load_failure_policy = "recompute" + return create_scheduler(vllm_config) + + +def test_sync_recompute_blocks_not_freed_for_running_requests( + recompute_scheduler: Scheduler, +): + """ + Test sync recompute case - blocks must not be freed for running requests. + + When a running request has invalid blocks and retry_policy is 'recompute': + 1. Request should remain in RUNNING state + 2. num_computed_tokens should be truncated to invalid block boundary + 3. Blocks should NOT be freed (request still needs them for recomputation) + 4. Request should remain in scheduler.requests and scheduler.running + """ + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * recompute_scheduler.block_size + ) + + request = create_request(num_tokens=num_prompt_tokens) + recompute_scheduler.add_request(request=request) + + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + # mock connector indicating sync load + recompute_scheduler.connector = Mock() + recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False) + ) + recompute_scheduler.connector.request_finished.return_value = (False, None) + recompute_scheduler.connector.take_events.return_value = () + + scheduler_output = recompute_scheduler.schedule() + + # request should be running with sync KV load + assert len(recompute_scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert request.status == RequestStatus.RUNNING + + # get the allocated block IDs before invalid blocks are reported + req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0] + invalid_block_ids = {req_block_ids[invalid_block_idx]} + + # store original num_computed_tokens for comparison + original_num_computed_tokens = request.num_computed_tokens + + model_runner_output = create_model_runner_output( + [request], + invalid_block_ids=invalid_block_ids, + use_eos=False, # not finished - should continue running + ) + + outputs = recompute_scheduler.update_from_output( + scheduler_output, model_runner_output + ) + + # critical assertions for recompute case: + + # 1. request should still be RUNNING (not finished, not aborted) + assert request.status == RequestStatus.RUNNING, ( + f"Request should remain RUNNING for recompute, got {request.status}" + ) + + # 2. num_computed_tokens should be truncated to first invalid block + expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size + assert request.num_computed_tokens == expected_truncated_tokens, ( + f"num_computed_tokens should be truncated to {expected_truncated_tokens}, " + f"got {request.num_computed_tokens}" + ) + assert request.num_computed_tokens < original_num_computed_tokens, ( + "num_computed_tokens should be reduced after invalid block detection" + ) + + # 3. no output should be generated (request is still running) + # the request should be skipped in the output loop + assert len(outputs) == 0 or request.request_id not in [ + out.request_id for outs in outputs.values() for out in outs.outputs + ], "No output should be generated for recompute requests" + + # 4. request should still be in running queue + assert request in recompute_scheduler.running, ( + "Request should remain in running queue for recomputation" + ) + + # 5. request should still be in scheduler.requests (not deleted) + assert request.request_id in recompute_scheduler.requests, ( + "Request should not be deleted from scheduler.requests" + ) + + # 6. blocks should NOT be freed - verify blocks are still allocated + try: + allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids( + request.request_id + ) + assert allocated_blocks is not None + assert len(allocated_blocks[0]) > 0, ( + "Blocks should still be allocated for recomputation" + ) + except KeyError: + pytest.fail( + "Blocks were freed incorrectly! Running requests need their blocks " + "to recompute invalid portions." + ) + + # 7. verify request can be rescheduled in next step + scheduler_output_2 = recompute_scheduler.schedule() + + # request should appear in the new schedule to recompute invalid blocks + scheduled_req_ids = [ + req.request_id for req in scheduler_output_2.scheduled_new_reqs + ] + if scheduler_output_2.num_scheduled_tokens: + scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys()) + + assert ( + request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0 + ), "Request should be reschedulable for recomputation" + + +def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler): + """ + Test sync fail case - invalid blocks must be evicted from cache. + + When a request fails with policy='fail' and has invalid blocks from sync loading: + 1. Request should be finished with FINISHED_ERROR + 2. Invalid blocks should be evicted from the KV cache + 3. Valid blocks (if shared) should remain in cache + 4. Future requests should not reuse the invalid blocks + + This test verifies that invalid blocks are properly evicted to prevent + cache corruption and reuse of invalid data. + """ + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * fail_scheduler.block_size + ) + + request = create_request(num_tokens=num_prompt_tokens) + fail_scheduler.add_request(request=request) + + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + # mock connector indicating sync load + fail_scheduler.connector = Mock() + fail_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False) + ) + fail_scheduler.connector.request_finished.return_value = (False, None) + fail_scheduler.connector.take_events.return_value = () + + scheduler_output = fail_scheduler.schedule() + + # request should be running with sync KV load + assert len(fail_scheduler.running) == 1 + assert request.status == RequestStatus.RUNNING + + # get allocated block IDs + req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0] + invalid_block_id = req_block_ids[invalid_block_idx] + invalid_block_ids = {invalid_block_id} + + # verify the block is in the block pool before we report it as invalid + block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id] + assert block is not None + + # report invalid blocks - request should fail + model_runner_output = create_model_runner_output( + [request], + invalid_block_ids=invalid_block_ids, + use_eos=True, + ) + + outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output) + + # verify request is finished with error + assert request.status == RequestStatus.FINISHED_ERROR + assert request.get_finished_reason() == FinishReason.ERROR + + # verify output is generated + assert len(outputs) == 1 + engine_outputs = next(iter(outputs.values())) + assert len(engine_outputs.outputs) == 1 + output = engine_outputs.outputs[0] + assert output.request_id == request.request_id + assert output.finish_reason == FinishReason.ERROR + + # verify the request was removed from scheduler + assert request.request_id not in fail_scheduler.requests + assert len(fail_scheduler.running) == 0 + + # critical: verify invalid block was actually freed from cache + # this is the key assertion - the invalid block should no longer be + # tracked by the KV cache manager for this request + # if it's still there, a future request could reuse the invalid data + try: + block_ids = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id) + # if we get here, check if blocks were actually freed + if block_ids is not None and len(block_ids[0]) > 0: + pytest.fail( + f"Invalid blocks still tracked for finished request! " + f"Request {request.request_id} should have been freed but " + f"still has {len(block_ids[0])} blocks allocated." + ) + # blocks list exists but is empty - this is fine, they were freed + except KeyError: + # expected - request completely removed from tracking + pass + + # critical: verify invalid block was evicted from prefix cache + # the block should no longer have a hash (hash is reset on eviction) + assert block.block_hash is None, ( + f"Invalid block {invalid_block_id} should have been evicted from cache " + f"(hash should be None), but hash is still {block.block_hash}" + ) + + +def test_async_recompute_blocks_not_cached_when_invalid( + recompute_scheduler: Scheduler, +): + """ + Test async recompute case - invalid blocks not cached after transfer. + + When async KV loading has invalid blocks and retry_policy is 'recompute': + 1. Blocks are allocated but not cached yet + 2. When async transfer completes, only valid blocks should be cached + 3. Invalid blocks should never enter the prefix cache + + This test verifies correctness, the failed_recving_kv_req_ids protection + ensures only valid blocks are cached when the transfer completes, and we + only evict blocks from cache that are already hashed in the block table. + """ + from unittest.mock import patch + + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * recompute_scheduler.block_size + ) + + request = create_request(num_tokens=num_prompt_tokens) + recompute_scheduler.add_request(request=request) + + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + # mock connector indicating async load + recompute_scheduler.connector = Mock() + recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, True) + ) + recompute_scheduler.connector.request_finished.return_value = (False, None) + recompute_scheduler.connector.take_events.return_value = () + + scheduler_output = recompute_scheduler.schedule() + + # request should be waiting for remote KVs + assert len(recompute_scheduler.waiting) == 1 + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert request.num_computed_tokens == 0 + + # get the allocated block IDs + (req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids( + request.request_id + ) + invalid_block_id = req_block_ids[invalid_block_idx] + invalid_block_ids = {invalid_block_id} + + # get the block object to verify it's not cached yet and stays uncached + block = recompute_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id] + + # verify block has no hash before invalid blocks are reported + assert block.block_hash is None, ( + "Async loading blocks should not be cached yet (no hash)" + ) + + # report invalid blocks (transfer not finished yet) + model_runner_output = create_model_runner_output( + reqs=[], + finished_recving=None, # transfer NOT finished + invalid_block_ids=invalid_block_ids, + use_eos=False, + ) + + # critical: spy on evict_blocks to verify it's NOT called for async blocks + original_evict_blocks = recompute_scheduler.kv_cache_manager.evict_blocks + evict_blocks_calls = [] + + def evict_blocks_spy(block_ids): + evict_blocks_calls.append(set(block_ids)) + return original_evict_blocks(block_ids) + + with patch.object( + recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy + ): + recompute_scheduler.update_from_output(scheduler_output, model_runner_output) + + # verify evict_blocks was NOT called (async blocks excluded from eviction) + assert len(evict_blocks_calls) == 0, ( + f"evict_blocks should not be called for async-only invalid blocks, " + f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}" + ) + + # request should still be waiting (not finished with error due to recompute policy) + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids + + # verify num_computed_tokens was truncated to before invalid block + expected_valid_tokens = invalid_block_idx * recompute_scheduler.block_size + assert request.num_computed_tokens == expected_valid_tokens + + # verify invalid block still has no hash (was not evicted) + assert block.block_hash is None, ( + f"Async loading blocks shouldn't be cached or evicted. " + f"Block {invalid_block_id} hash should be None but is {block.block_hash}" + ) + + # now simulate async transfer completing + model_runner_output_2 = create_model_runner_output( + reqs=[], + finished_recving={request.request_id}, + invalid_block_ids=None, + use_eos=False, + ) + + recompute_scheduler.update_from_output(scheduler_output, model_runner_output_2) + + # verify request is now marked as finished receiving and ready to be processed + assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids + assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids + + # critical: verify invalid block still has no hash before recompute + # the async transfer invalid data was never cached + assert block.block_hash is None, ( + f"Invalid block {invalid_block_id} should not be cached before recompute " + f"(hash should be None), but hash is {block.block_hash}" + ) + + # critical end-to-end test: spy on cache_blocks to verify it's called with + # the truncated num_computed_tokens value + original_cache_blocks = recompute_scheduler.kv_cache_manager.cache_blocks + cache_blocks_calls = [] + + def cache_blocks_spy(req, num_tokens): + cache_blocks_calls.append((req.request_id, num_tokens)) + return original_cache_blocks(req, num_tokens) + + with patch.object( + recompute_scheduler.kv_cache_manager, "cache_blocks", cache_blocks_spy + ): + # call schedule() again - this triggers _update_waiting_for_remote_kv() + # which should call cache_blocks with the truncated value + recompute_scheduler.schedule() + + # verify cache_blocks was called with the truncated value + assert len(cache_blocks_calls) == 1, ( + f"cache_blocks should be called exactly once, " + f"got {len(cache_blocks_calls)} calls" + ) + cached_req_id, cached_num_tokens = cache_blocks_calls[0] + assert cached_req_id == request.request_id + assert cached_num_tokens == expected_valid_tokens, ( + f"cache_blocks should be called with truncated value {expected_valid_tokens}, " + f"but was called with {cached_num_tokens}" + ) + + # request should now be RUNNING (scheduled immediately after transfer completes) + # the flow is: WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING in same schedule() call + assert request.status == RequestStatus.RUNNING + + # num_computed_tokens should be >= expected_valid_tokens because the scheduler + # will schedule additional new tokens (up to max_num_batched_tokens) for the request + assert request.num_computed_tokens >= expected_valid_tokens, ( + f"num_computed_tokens should be at least {expected_valid_tokens}, " + f"got {request.num_computed_tokens}" + ) + + # request should no longer be in the failed/finished receiving sets + assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids + assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids + + # request should be in the running queue + assert request in recompute_scheduler.running diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py index 88f8b91c292bb..98cea821c678e 100644 --- a/vllm/config/kv_transfer.py +++ b/vllm/config/kv_transfer.py @@ -64,6 +64,11 @@ class KVTransferConfig: enable_permute_local_kv: bool = False """Experiment feature flag to enable HND to NHD KV Transfer""" + kv_load_failure_policy: Literal["recompute", "fail"] = "recompute" + """Policy for handling KV cache load failures. + 'recompute': reschedule the request to recompute failed blocks (default) + 'fail': immediately fail the request with an error finish reason""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c6333d170c663..2560a5b2cdf41 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -51,7 +51,11 @@ from vllm.entrypoints.openai.protocol import ( ToolCall, UsageInfo, ) -from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs +from vllm.entrypoints.openai.serving_engine import ( + GenerationError, + OpenAIServing, + clamp_prompt_logprobs, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall @@ -380,6 +384,8 @@ class OpenAIServingChat(OpenAIServing): tokenizer, request_metadata, ) + except GenerationError as e: + return self._convert_generation_error_to_response(e) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -1120,6 +1126,10 @@ class OpenAIServingChat(OpenAIServing): # if the model is finished generating else: + # check for error finish reason and abort streaming + # finish_reason='error' indicates a retryable error + self._raise_if_error(output.finish_reason, request_id) + # check to make sure we haven't "forgotten" to stream # any tokens that were generated but previously # matched by partial json parsing @@ -1287,6 +1297,8 @@ class OpenAIServingChat(OpenAIServing): delta=False, ) + except GenerationError as e: + yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n" except Exception as e: # TODO: Use a vllm-specific Validation Error logger.exception("Error in chat completion stream generator.") @@ -1327,6 +1339,9 @@ class OpenAIServingChat(OpenAIServing): role = self.get_chat_request_role(request) for output in final_res.outputs: + # check for error finish reason and raise GenerationError + # finish_reason='error' indicates a retryable request-level internal error + self._raise_if_error(output.finish_reason, request_id) token_ids = output.token_ids out_logprobs = output.logprobs tool_call_info = None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 3e421e21e3e80..1be0afc8c74e5 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -24,7 +24,11 @@ from vllm.entrypoints.openai.protocol import ( RequestResponseMetadata, UsageInfo, ) -from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs +from vllm.entrypoints.openai.serving_engine import ( + GenerationError, + OpenAIServing, + clamp_prompt_logprobs, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import get_max_tokens, should_include_usage @@ -300,6 +304,8 @@ class OpenAIServingCompletion(OpenAIServing): ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") + except GenerationError as e: + return self._convert_generation_error_to_response(e) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -437,6 +443,8 @@ class OpenAIServingCompletion(OpenAIServing): finish_reason = output.finish_reason stop_reason = output.stop_reason + self._raise_if_error(finish_reason, request_id) + chunk = CompletionStreamResponse( id=request_id, created=created_time, @@ -498,8 +506,11 @@ class OpenAIServingCompletion(OpenAIServing): # report to FastAPI middleware aggregate usage across all choices request_metadata.final_usage_info = final_usage_info + except GenerationError as e: + yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n" except Exception as e: # TODO: Use a vllm-specific Validation Error + logger.exception("Error in completion stream generator.") data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" yield "data: [DONE]\n\n" @@ -530,6 +541,8 @@ class OpenAIServingCompletion(OpenAIServing): out_logprobs: GenericSequence[dict[int, Logprob] | None] | None for output in final_res.outputs: + self._raise_if_error(output.finish_reason, request_id) + assert request.max_tokens is not None if request.echo: if request.return_token_ids: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 44b0f1842a6c1..a799432baeb40 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -133,6 +133,15 @@ from vllm.utils.async_utils import ( from vllm.utils.collection_utils import is_list_of from vllm.v1.engine import EngineCoreRequest + +class GenerationError(Exception): + """raised when finish_reason indicates internal server error (500)""" + + def __init__(self, message: str = "Internal server error"): + super().__init__(message) + self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + + logger = init_logger(__name__) CompletionLikeRequest: TypeAlias = ( @@ -456,6 +465,29 @@ class OpenAIServing: # Iterate through all beam inference results for i, result in enumerate(output): current_beam = all_beams[i] + + # check for error finish reason and abort beam search + if result.outputs[0].finish_reason == "error": + # yield error output and terminate beam search + yield RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput( + index=0, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + ], + finished=True, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + ) + return + if result.outputs[0].logprobs is not None: logprobs = result.outputs[0].logprobs[0] all_beams_token_id.extend(list(logprobs.keys())) @@ -780,6 +812,35 @@ class OpenAIServing: ) return json_str + def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None: + """Raise GenerationError if finish_reason indicates an error.""" + if finish_reason == "error": + logger.error( + "Request %s failed with an internal error during generation", + request_id, + ) + raise GenerationError("Internal server error") + + def _convert_generation_error_to_response( + self, e: GenerationError + ) -> ErrorResponse: + """Convert GenerationError to ErrorResponse.""" + return self.create_error_response( + str(e), + err_type="InternalServerError", + status_code=e.status_code, + ) + + def _convert_generation_error_to_streaming_response( + self, e: GenerationError + ) -> str: + """Convert GenerationError to streaming error response.""" + return self.create_streaming_error_response( + str(e), + err_type="InternalServerError", + status_code=e.status_code, + ) + async def _check_model( self, request: AnyRequest, diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 91616a78e11dc..60d14337dcaaf 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -50,6 +50,7 @@ from openai.types.responses.response_reasoning_item import ( ) from openai.types.responses.tool import Mcp, Tool from openai_harmony import Message as OpenAIHarmonyMessage +from pydantic import TypeAdapter from vllm import envs from vllm.engine.protocol import EngineClient @@ -94,7 +95,10 @@ from vllm.entrypoints.openai.protocol import ( ResponseUsage, StreamingResponsesResponse, ) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import ( + GenerationError, + OpenAIServing, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.responses_utils import ( construct_input_messages, @@ -541,6 +545,8 @@ class OpenAIServingResponses(OpenAIServing): tokenizer, request_metadata, ) + except GenerationError as e: + return self._convert_generation_error_to_response(e) except Exception as e: return self.create_error_response(str(e)) @@ -648,6 +654,8 @@ class OpenAIServingResponses(OpenAIServing): status = "incomplete" elif context.finish_reason == "abort": status = "cancelled" + else: + self._raise_if_error(context.finish_reason, request.request_id) else: status = "incomplete" elif isinstance(context, ParsableContext): @@ -673,6 +681,9 @@ class OpenAIServingResponses(OpenAIServing): assert len(final_res.outputs) == 1 final_output = final_res.outputs[0] + # finish_reason='error' indicates retryable internal error + self._raise_if_error(final_output.finish_reason, request.request_id) + output = self._make_response_output_items(request, final_output, tokenizer) if request.enable_response_messages: @@ -1066,6 +1077,8 @@ class OpenAIServingResponses(OpenAIServing): async for event in generator: event_deque.append(event) new_event_signal.set() # Signal new event available + except GenerationError as e: + response = self._convert_generation_error_to_response(e) except Exception as e: logger.exception("Background request failed for %s", request.request_id) response = self.create_error_response(str(e)) @@ -1089,6 +1102,8 @@ class OpenAIServingResponses(OpenAIServing): ): try: response = await self.responses_full_generator(request, *args, **kwargs) + except GenerationError as e: + response = self._convert_generation_error_to_response(e) except Exception as e: logger.exception("Background request failed for %s", request.request_id) response = self.create_error_response(str(e)) @@ -1227,6 +1242,8 @@ class OpenAIServingResponses(OpenAIServing): continue if ctx.last_output.outputs: output = ctx.last_output.outputs[0] + # finish_reason='error' indicates a retryable error + self._raise_if_error(output.finish_reason, request.request_id) if reasoning_parser: delta_message = reasoning_parser.extract_reasoning_streaming( previous_text=previous_text, @@ -1522,6 +1539,9 @@ class OpenAIServingResponses(OpenAIServing): async for ctx in result_generator: assert isinstance(ctx, StreamingHarmonyContext) + # finish_reason='error' indicates a retryable error + self._raise_if_error(ctx.finish_reason, request.request_id) + if ctx.is_expecting_start(): current_output_index += 1 sent_output_item_added = False @@ -2016,18 +2036,25 @@ class OpenAIServingResponses(OpenAIServing): ) ) - async for event_data in processer( - request, - sampling_params, - result_generator, - context, - model_name, - tokenizer, - request_metadata, - created_time, - _increment_sequence_number_and_return, - ): - yield event_data + try: + async for event_data in processer( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + _increment_sequence_number_and_return, + ): + yield event_data + except GenerationError as e: + error_json = self._convert_generation_error_to_streaming_response(e) + yield _increment_sequence_number_and_return( + TypeAdapter(StreamingResponsesResponse).validate_json(error_json) + ) + return async def empty_async_generator(): # A hack to trick Python to think this is a generator but diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index cfb2c02e00f1b..c779e3d34b3ed 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -397,6 +397,25 @@ class BlockPool: [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] ) + def evict_blocks(self, block_ids: set[int]) -> None: + """evict blocks from the prefix cache by their block IDs. + + only evicts blocks that are currently cached (have a hash). blocks + with ref_cnt > 0 are not freed from the block pool, only evicted + from the prefix cache hash table. + + Args: + block_ids: Set of block IDs to evict from cache. + """ + for block_id in block_ids: + assert block_id < len(self.blocks), ( + f"Invalid block_id {block_id} >= {len(self.blocks)}. " + f"This indicates a bug in the KV connector - workers should " + f"only report block IDs that were allocated by the scheduler." + ) + block = self.blocks[block_id] + self._maybe_evict_cached_block(block) + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalid prefix caching after the weights are updated, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 33e8c81514c5f..13086a66f6ea6 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -333,6 +333,14 @@ class KVCacheManager: """ self.coordinator.free(request.request_id) + def evict_blocks(self, block_ids: set[int]) -> None: + """evict blocks from the prefix cache by their block IDs. + + Args: + block_ids: Set of block IDs to evict from cache. + """ + self.block_pool.evict_blocks(block_ids) + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalidate prefix caching after the weights are updated, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d858e840039c4..c3d504f2e72c3 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -106,6 +106,7 @@ class Scheduler(SchedulerInterface): # KV Connector pushes/pull of remote KVs for P/D and offloading. self.connector = None self.connector_prefix_cache_stats: PrefixCacheStats | None = None + self.recompute_kv_load_failures = True if self.vllm_config.kv_transfer_config is not None: assert not self.is_encoder_decoder, ( "Encoder-decoder models are not currently supported with KV connectors" @@ -117,6 +118,10 @@ class Scheduler(SchedulerInterface): ) if self.log_stats: self.connector_prefix_cache_stats = PrefixCacheStats() + kv_load_failure_policy = ( + self.vllm_config.kv_transfer_config.kv_load_failure_policy + ) + self.recompute_kv_load_failures = kv_load_failure_policy == "recompute" self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, @@ -1066,7 +1071,7 @@ class Scheduler(SchedulerInterface): for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): assert num_tokens_scheduled > 0 if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: - # Skip requests that were recovered from KV load failure + # skip failed or rescheduled requests from KV load failure continue request = self.requests.get(req_id) if request is None: @@ -1177,6 +1182,21 @@ class Scheduler(SchedulerInterface): # This is a rare case and unlikely to impact performance. self.waiting.remove_requests(stopped_preempted_reqs) + if failed_kv_load_req_ids and not self.recompute_kv_load_failures: + requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids] + self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR) + for request in requests: + outputs[request.client_index].append( + EngineCoreOutput( + request_id=request.request_id, + new_token_ids=[], + finish_reason=request.get_finished_reason(), + events=request.take_events(), + trace_headers=request.trace_headers, + num_cached_tokens=request.num_cached_tokens, + ) + ) + # KV Connector: update state for finished KV Transfers. if kv_connector_output: self._update_from_kv_xfer_finished(kv_connector_output) @@ -1610,8 +1630,11 @@ class Scheduler(SchedulerInterface): self._free_blocks(self.requests[req_id]) def _update_requests_with_invalid_blocks( - self, requests: Iterable[Request], invalid_block_ids: set[int] - ) -> tuple[set[str], int]: + self, + requests: Iterable[Request], + invalid_block_ids: set[int], + evict_blocks: bool = True, + ) -> tuple[set[str], int, set[int]]: """ Identify and update requests affected by invalid KV cache blocks. @@ -1623,16 +1646,21 @@ class Scheduler(SchedulerInterface): Args: requests: The set of requests to scan for invalid blocks. invalid_block_ids: IDs of invalid blocks. + evict_blocks: Whether to collect blocks for eviction (False for + async requests which aren't cached yet). Returns: tuple: - affected_req_ids (set[str]): IDs of requests impacted by invalid blocks. - total_affected_tokens (int): Total number of tokens that must - be recomputed across all affected requests (for observability). + be recomputed across all affected requests. + - blocks_to_evict (set[int]): Block IDs to evict from cache, + including invalid blocks and downstream dependent blocks. """ affected_req_ids: set[str] = set() total_affected_tokens = 0 + blocks_to_evict: set[int] = set() # If a block is invalid and shared by multiple requests in the batch, # these requests must be rescheduled, but only the first will recompute # it. This set tracks blocks already marked for recomputation. @@ -1690,6 +1718,9 @@ class Scheduler(SchedulerInterface): ) total_affected_tokens += num_affected_tokens request.num_external_computed_tokens -= num_affected_tokens + # collect invalid block and all downstream dependent blocks + if evict_blocks: + blocks_to_evict.update(req_block_ids[idx:]) if is_affected: if not marked_invalid_block: @@ -1705,47 +1736,70 @@ class Scheduler(SchedulerInterface): affected_req_ids.add(request.request_id) - return affected_req_ids, total_affected_tokens + return affected_req_ids, total_affected_tokens, blocks_to_evict def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: - total_requests_to_reschedule = 0 - total_tokens_to_reschedule = 0 + """ + Handle requests affected by invalid KV cache blocks. - # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- + Returns: + Set of affected request IDs to skip in update_from_output main loop. + """ + should_fail = not self.recompute_kv_load_failures + + # handle async KV loads (not cached yet, evict_blocks=False) async_load_reqs = ( req for req in self.waiting if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS ) - async_affected_req_ids, num_tokens_to_reschedule = ( + async_failed_req_ids, num_failed_tokens, _ = ( self._update_requests_with_invalid_blocks( - async_load_reqs, invalid_block_ids + async_load_reqs, invalid_block_ids, evict_blocks=False ) ) - total_requests_to_reschedule += len(async_affected_req_ids) - total_tokens_to_reschedule += num_tokens_to_reschedule + total_failed_requests = len(async_failed_req_ids) + total_failed_tokens = num_failed_tokens - # Mark requests with async KV load failures; they will be rescheduled - # once loading completes. - self.failed_recving_kv_req_ids |= async_affected_req_ids - - # --- Handle sync KV loads (running requests) --- - sync_affected_req_ids, num_tokens_to_reschedule = ( - self._update_requests_with_invalid_blocks(self.running, invalid_block_ids) + # handle sync loads (may be cached, collect blocks for eviction) + sync_failed_req_ids, num_failed_tokens, sync_blocks_to_evict = ( + self._update_requests_with_invalid_blocks( + self.running, invalid_block_ids, evict_blocks=True + ) ) - total_requests_to_reschedule += len(sync_affected_req_ids) - total_tokens_to_reschedule += num_tokens_to_reschedule + total_failed_requests += len(sync_failed_req_ids) + total_failed_tokens += num_failed_tokens - if total_requests_to_reschedule: - logger.warning( - "Recovered from KV load failure: " - "%d request(s) rescheduled (%d tokens affected).", - total_requests_to_reschedule, - total_tokens_to_reschedule, + if not total_failed_requests: + return set() + + # evict invalid blocks and downstream dependent blocks from cache + # only when not using recompute policy (where blocks will be recomputed + # and reused by other requests sharing them) + if sync_blocks_to_evict and not self.recompute_kv_load_failures: + self.kv_cache_manager.evict_blocks(sync_blocks_to_evict) + + if should_fail: + all_failed_req_ids = async_failed_req_ids | sync_failed_req_ids + logger.error( + "Failing %d request(s) due to KV load failure " + "(failure_policy=fail, %d tokens affected). Request IDs: %s", + total_failed_requests, + total_failed_tokens, + all_failed_req_ids, ) + return all_failed_req_ids - # Return the IDs of affected running requests to skip in - # update_from_output. - return sync_affected_req_ids + logger.warning( + "Recovered from KV load failure: " + "%d request(s) rescheduled (%d tokens affected).", + total_failed_requests, + total_failed_tokens, + ) + + # Mark async requests with KV load failures for retry once loading completes + self.failed_recving_kv_req_ids |= async_failed_req_ids + # Return sync affected IDs to skip in update_from_output + return sync_failed_req_ids diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index ce2aae77108da..4f54d12f4b8d0 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -19,24 +19,27 @@ from vllm.v1.serial_utils import UtilityResult # These are possible values of RequestOutput.finish_reason, # so form part of the external API. -FINISH_REASON_STRINGS = ("stop", "length", "abort") +FINISH_REASON_STRINGS = ("stop", "length", "abort", "error") class FinishReason(enum.IntEnum): """ - Reason a request finished - stop, length, or abort. + Reason a request finished - stop, length, abort, or error. Int rather than Str for more compact serialization. stop - a stop string was emitted length - max_tokens was consumed, or max_model_len was reached - abort - aborted for another reason + abort - aborted by client + error - retryable request-level internal error (e.g., KV load failure). + Invariant: always converted to 500 Internal Server Error. """ STOP = 0 LENGTH = 1 ABORT = 2 + ERROR = 3 def __str__(self): return FINISH_REASON_STRINGS[self.value] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 33762fe34e64f..a775e840e841c 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -255,6 +255,7 @@ class RequestStatus(enum.IntEnum): FINISHED_LENGTH_CAPPED = enum.auto() FINISHED_ABORTED = enum.auto() FINISHED_IGNORED = enum.auto() + FINISHED_ERROR = enum.auto() def __str__(self): return self.name @@ -277,4 +278,5 @@ _FINISHED_REASON_MAP = { RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH, RequestStatus.FINISHED_ABORTED: FinishReason.ABORT, RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH, + RequestStatus.FINISHED_ERROR: FinishReason.ERROR, }