[P/D] KV Load Failure Recovery/Abort Configuration (#26813)

Signed-off-by: Will Eaton <weaton@redhat.com>
Signed-off-by: Will Eaton <me@wseaton.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: chaunceyjiang <chaunceyjiang@gmail.com>
Will Eaton 2025-12-10 14:00:52 -05:00 committed by GitHub
parent e8e8cd73e5
commit a9e4106f28
16 changed files with 1552 additions and 48 deletions

View File

@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
models = OpenAIServingModels(
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_chat = OpenAIServingChat(
engine,
models,
response_role="assistant",
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
):
return dict(engine_prompt), {}
async def _fake_preprocess_chat(*args, **kwargs):
# return conversation, request_prompts, engine_prompts
return (
[{"role": "user", "content": "Test"}],
[[1, 2, 3]],
[{"prompt_token_ids": [1, 2, 3]}],
)
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
return serving_chat
@pytest.mark.asyncio
async def test_chat_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_chat = _build_serving_chat(mock_engine)
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Test prompt"}],
max_tokens=10,
stream=False,
)
response = await serving_chat.create_chat_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_chat_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_chat = _build_serving_chat(mock_engine)
completion_output_1 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
)
request_output_1 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_1],
finished=False,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
completion_output_2 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output_2 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_2],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output_1
yield request_output_2
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Test prompt"}],
max_tokens=10,
stream=True,
)
response = await serving_chat.create_chat_completion(request)
chunks = []
async for chunk in response:
chunks.append(chunk)
assert len(chunks) >= 2
assert any("Internal server error" in chunk for chunk in chunks), (
f"Expected error message in chunks: {chunks}"
)
assert chunks[-1] == "data: [DONE]\n\n"

View File

@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
models = OpenAIServingModels(
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_completion = OpenAIServingCompletion(
engine,
models,
request_logger=None,
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
):
return dict(engine_prompt), {}
serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_completion
@pytest.mark.asyncio
async def test_completion_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_completion = _build_serving_completion(mock_engine)
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
stream=False,
)
response = await serving_completion.create_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_completion_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_completion = _build_serving_completion(mock_engine)
completion_output_1 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
)
request_output_1 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_1],
finished=False,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
completion_output_2 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output_2 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_2],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output_1
yield request_output_2
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
stream=True,
)
response = await serving_completion.create_completion(request)
chunks = []
async for chunk in response:
chunks.append(chunk)
assert len(chunks) >= 2
assert any("Internal server error" in chunk for chunk in chunks), (
f"Expected error message in chunks: {chunks}"
)
assert chunks[-1] == "data: [DONE]\n\n"

View File

@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.serving_engine import GenerationError, OpenAIServing
@pytest.mark.asyncio
async def test_raise_if_error_raises_generation_error():
"""test _raise_if_error raises GenerationError"""
# create a minimal OpenAIServing instance
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# test that error finish_reason raises GenerationError
with pytest.raises(GenerationError) as exc_info:
serving._raise_if_error("error", "test-request-id")
assert str(exc_info.value) == "Internal server error"
assert exc_info.value.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
# test that other finish_reasons don't raise
serving._raise_if_error("stop", "test-request-id") # should not raise
serving._raise_if_error("length", "test-request-id") # should not raise
serving._raise_if_error(None, "test-request-id") # should not raise
@pytest.mark.asyncio
async def test_convert_generation_error_to_response():
"""test _convert_generation_error_to_response creates proper ErrorResponse"""
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# create a GenerationError
gen_error = GenerationError("Internal server error")
# convert to ErrorResponse
error_response = serving._convert_generation_error_to_response(gen_error)
assert isinstance(error_response, ErrorResponse)
assert error_response.error.type == "InternalServerError"
assert error_response.error.message == "Internal server error"
assert error_response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_convert_generation_error_to_streaming_response():
"""test _convert_generation_error_to_streaming_response output"""
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# create a GenerationError
gen_error = GenerationError("Internal server error")
# convert to streaming error response
error_json = serving._convert_generation_error_to_streaming_response(gen_error)
assert isinstance(error_json, str)
assert "Internal server error" in error_json
assert "InternalServerError" in error_json

View File

@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that invalid blocks are evicted from the prefix cache to prevent pollution.
Verifies that when sync loading fails, invalid blocks are removed from the
prefix-cache hash table so future requests cannot match and reuse corrupted data.
"""
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
def test_invalid_blocks_evicted_prevents_cache_pollution(
fail_scheduler: Scheduler,
):
"""
verify invalid blocks are evicted to prevent future cache hits.
scenario:
1. request 1 loads externally-computed blocks (sync mode)
2. some blocks fail to load and are marked invalid
3. with fail policy, invalid blocks should be evicted from prefix cache
4. request is marked as FINISHED_ERROR
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
# request 1: will have invalid blocks
request1 = create_request(num_tokens=num_prompt_tokens, request_id=1)
fail_scheduler.add_request(request=request1)
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
# request should be running with sync KV load
assert len(fail_scheduler.running) == 1
assert request1.status == RequestStatus.RUNNING
# get allocated block IDs
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# get the block object to verify eviction later
block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
# cache the blocks to simulate that they have been computed and cached
# (in a real scenario, blocks are cached after being computed)
fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens)
# verify block has a hash (is cached) before reporting invalid blocks
assert block.block_hash is not None, (
f"block {invalid_block_id} should be cached (have a hash) before "
f"eviction test, but hash is None"
)
# report invalid blocks
model_runner_output = create_model_runner_output(
[request1],
invalid_block_ids=invalid_block_ids,
use_eos=False,
)
fail_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify request finished with error (fail policy)
assert request1.status == RequestStatus.FINISHED_ERROR
# critical assertion: invalid block and all subsequent blocks should be evicted
# all blocks from invalid_block_idx onwards become invalid since they were
# computed based on the failed block
for idx in range(invalid_block_idx, len(req_block_ids)):
block_id = req_block_ids[idx]
block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
assert block_obj.block_hash is None, (
f"block {block_id} at index {idx} should have been evicted "
f"(hash reset to None), but hash is {block_obj.block_hash}. "
f"All blocks from index {invalid_block_idx} onwards should be evicted "
f"since they depend on the invalid block at index {invalid_block_idx}."
)
# verify cache contains exactly the valid blocks (before first affected block)
# and none of the invalid blocks (from first affected block onwards)
# valid blocks: all blocks before invalid_block_idx should be cached
for idx in range(invalid_block_idx):
block_id = req_block_ids[idx]
block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
assert block_obj.block_hash is not None, (
f"valid block {block_id} at index {idx} should still be cached "
f"(have a hash), but hash is None. Only blocks from index "
f"{invalid_block_idx} onwards should be evicted."
)
# invalid blocks: verify they're not in the cached_block_hash_to_block map
cached_blocks = (
fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
)
cached_block_ids = {
b.block_id
for blocks_val in cached_blocks._cache.values()
for b in (
[blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values()
)
}
for idx in range(invalid_block_idx, len(req_block_ids)):
block_id = req_block_ids[idx]
assert block_id not in cached_block_ids, (
f"invalid block {block_id} at index {idx} should not be in cache hash table"
)

View File

@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
def test_error_propagation_sync_load(fail_scheduler: Scheduler):
"""test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
assert len(fail_scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
assert len(fail_scheduler.running) == 0
def test_error_propagation_async_load(fail_scheduler: Scheduler):
"""test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
assert len(fail_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
(req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=set(),
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
assert len(fail_scheduler.waiting) == 0

View File

@ -0,0 +1,454 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for correctness in invalid block handling.
These tests verify correct behavior in three scenarios:
1. Sync recompute case: Blocks should not be freed for running requests
that need to recompute invalid blocks
2. Sync fail case: Invalid blocks must be evicted from cache when request fails
3. Async recompute case: Invalid blocks should not be cached after transfer
"""
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
@pytest.fixture
def recompute_scheduler():
"""scheduler with kv_load_failure_policy='recompute'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "recompute"
return create_scheduler(vllm_config)
def test_sync_recompute_blocks_not_freed_for_running_requests(
recompute_scheduler: Scheduler,
):
"""
Test sync recompute case - blocks must not be freed for running requests.
When a running request has invalid blocks and retry_policy is 'recompute':
1. Request should remain in RUNNING state
2. num_computed_tokens should be truncated to invalid block boundary
3. Blocks should NOT be freed (request still needs them for recomputation)
4. Request should remain in scheduler.requests and scheduler.running
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * recompute_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
recompute_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
recompute_scheduler.connector = Mock()
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
recompute_scheduler.connector.request_finished.return_value = (False, None)
recompute_scheduler.connector.take_events.return_value = ()
scheduler_output = recompute_scheduler.schedule()
# request should be running with sync KV load
assert len(recompute_scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert request.status == RequestStatus.RUNNING
# get the allocated block IDs before invalid blocks are reported
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req_block_ids[invalid_block_idx]}
# store original num_computed_tokens for comparison
original_num_computed_tokens = request.num_computed_tokens
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=False, # not finished - should continue running
)
outputs = recompute_scheduler.update_from_output(
scheduler_output, model_runner_output
)
# critical assertions for recompute case:
# 1. request should still be RUNNING (not finished, not aborted)
assert request.status == RequestStatus.RUNNING, (
f"Request should remain RUNNING for recompute, got {request.status}"
)
# 2. num_computed_tokens should be truncated to first invalid block
expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size
assert request.num_computed_tokens == expected_truncated_tokens, (
f"num_computed_tokens should be truncated to {expected_truncated_tokens}, "
f"got {request.num_computed_tokens}"
)
assert request.num_computed_tokens < original_num_computed_tokens, (
"num_computed_tokens should be reduced after invalid block detection"
)
# 3. no output should be generated (request is still running)
# the request should be skipped in the output loop
assert len(outputs) == 0 or request.request_id not in [
out.request_id for outs in outputs.values() for out in outs.outputs
], "No output should be generated for recompute requests"
# 4. request should still be in running queue
assert request in recompute_scheduler.running, (
"Request should remain in running queue for recomputation"
)
# 5. request should still be in scheduler.requests (not deleted)
assert request.request_id in recompute_scheduler.requests, (
"Request should not be deleted from scheduler.requests"
)
# 6. blocks should NOT be freed - verify blocks are still allocated
try:
allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids(
request.request_id
)
assert allocated_blocks is not None
assert len(allocated_blocks[0]) > 0, (
"Blocks should still be allocated for recomputation"
)
except KeyError:
pytest.fail(
"Blocks were freed incorrectly! Running requests need their blocks "
"to recompute invalid portions."
)
# 7. verify request can be rescheduled in next step
scheduler_output_2 = recompute_scheduler.schedule()
# request should appear in the new schedule to recompute invalid blocks
scheduled_req_ids = [
req.request_id for req in scheduler_output_2.scheduled_new_reqs
]
if scheduler_output_2.num_scheduled_tokens:
scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys())
assert (
request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0
), "Request should be reschedulable for recomputation"
def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
"""
Test sync fail case - invalid blocks must be evicted from cache.
When a request fails with policy='fail' and has invalid blocks from sync loading:
1. Request should be finished with FINISHED_ERROR
2. Invalid blocks should be evicted from the KV cache
3. Valid blocks (if shared) should remain in cache
4. Future requests should not reuse the invalid blocks
This test verifies that invalid blocks are properly evicted to prevent
cache corruption and reuse of invalid data.
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
# request should be running with sync KV load
assert len(fail_scheduler.running) == 1
assert request.status == RequestStatus.RUNNING
# get allocated block IDs
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# verify the block is in the block pool before we report it as invalid
block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
assert block is not None
# report invalid blocks - request should fail
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify request is finished with error
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
# verify output is generated
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
# verify the request was removed from scheduler
assert request.request_id not in fail_scheduler.requests
assert len(fail_scheduler.running) == 0
# critical: verify invalid block was actually freed from cache
# this is the key assertion - the invalid block should no longer be
# tracked by the KV cache manager for this request
# if it's still there, a future request could reuse the invalid data
try:
block_ids = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
# if we get here, check if blocks were actually freed
if block_ids is not None and len(block_ids[0]) > 0:
pytest.fail(
f"Invalid blocks still tracked for finished request! "
f"Request {request.request_id} should have been freed but "
f"still has {len(block_ids[0])} blocks allocated."
)
# blocks list exists but is empty - this is fine, they were freed
except KeyError:
# expected - request completely removed from tracking
pass
# critical: verify invalid block was evicted from prefix cache
# the block should no longer have a hash (hash is reset on eviction)
assert block.block_hash is None, (
f"Invalid block {invalid_block_id} should have been evicted from cache "
f"(hash should be None), but hash is still {block.block_hash}"
)
def test_async_recompute_blocks_not_cached_when_invalid(
recompute_scheduler: Scheduler,
):
"""
Test async recompute case - invalid blocks not cached after transfer.
When async KV loading has invalid blocks and retry_policy is 'recompute':
1. Blocks are allocated but not cached yet
2. When async transfer completes, only valid blocks should be cached
3. Invalid blocks should never enter the prefix cache
This test verifies correctness: the failed_recving_kv_req_ids protection
ensures only valid blocks are cached when the transfer completes, and we
only evict blocks from the cache that are already hashed in the block table.
"""
from unittest.mock import patch
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * recompute_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
recompute_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating async load
recompute_scheduler.connector = Mock()
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
)
recompute_scheduler.connector.request_finished.return_value = (False, None)
recompute_scheduler.connector.take_events.return_value = ()
scheduler_output = recompute_scheduler.schedule()
# request should be waiting for remote KVs
assert len(recompute_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
# get the allocated block IDs
(req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
request.request_id
)
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# get the block object to verify it's not cached yet and stays uncached
block = recompute_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
# verify block has no hash before invalid blocks are reported
assert block.block_hash is None, (
"Async loading blocks should not be cached yet (no hash)"
)
# report invalid blocks (transfer not finished yet)
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=None, # transfer NOT finished
invalid_block_ids=invalid_block_ids,
use_eos=False,
)
# critical: spy on evict_blocks to verify it's NOT called for async blocks
original_evict_blocks = recompute_scheduler.kv_cache_manager.evict_blocks
evict_blocks_calls = []
def evict_blocks_spy(block_ids):
evict_blocks_calls.append(set(block_ids))
return original_evict_blocks(block_ids)
with patch.object(
recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
):
recompute_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify evict_blocks was NOT called (async blocks excluded from eviction)
assert len(evict_blocks_calls) == 0, (
f"evict_blocks should not be called for async-only invalid blocks, "
f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}"
)
# request should still be waiting (not finished with error due to recompute policy)
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
# verify num_computed_tokens was truncated to before invalid block
expected_valid_tokens = invalid_block_idx * recompute_scheduler.block_size
assert request.num_computed_tokens == expected_valid_tokens
# verify invalid block still has no hash (was not evicted)
assert block.block_hash is None, (
f"Async loading blocks shouldn't be cached or evicted. "
f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
)
# now simulate async transfer completing
model_runner_output_2 = create_model_runner_output(
reqs=[],
finished_recving={request.request_id},
invalid_block_ids=None,
use_eos=False,
)
recompute_scheduler.update_from_output(scheduler_output, model_runner_output_2)
# verify request is now marked as finished receiving and ready to be processed
assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
# critical: verify invalid block still has no hash before recompute
# the async transfer invalid data was never cached
assert block.block_hash is None, (
f"Invalid block {invalid_block_id} should not be cached before recompute "
f"(hash should be None), but hash is {block.block_hash}"
)
# critical end-to-end test: spy on cache_blocks to verify it's called with
# the truncated num_computed_tokens value
original_cache_blocks = recompute_scheduler.kv_cache_manager.cache_blocks
cache_blocks_calls = []
def cache_blocks_spy(req, num_tokens):
cache_blocks_calls.append((req.request_id, num_tokens))
return original_cache_blocks(req, num_tokens)
with patch.object(
recompute_scheduler.kv_cache_manager, "cache_blocks", cache_blocks_spy
):
# call schedule() again - this triggers _update_waiting_for_remote_kv()
# which should call cache_blocks with the truncated value
recompute_scheduler.schedule()
# verify cache_blocks was called with the truncated value
assert len(cache_blocks_calls) == 1, (
f"cache_blocks should be called exactly once, "
f"got {len(cache_blocks_calls)} calls"
)
cached_req_id, cached_num_tokens = cache_blocks_calls[0]
assert cached_req_id == request.request_id
assert cached_num_tokens == expected_valid_tokens, (
f"cache_blocks should be called with truncated value {expected_valid_tokens}, "
f"but was called with {cached_num_tokens}"
)
# request should now be RUNNING (scheduled immediately after transfer completes)
# the flow is: WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING in same schedule() call
assert request.status == RequestStatus.RUNNING
# num_computed_tokens should be >= expected_valid_tokens because the scheduler
# will schedule additional new tokens (up to max_num_batched_tokens) for the request
assert request.num_computed_tokens >= expected_valid_tokens, (
f"num_computed_tokens should be at least {expected_valid_tokens}, "
f"got {request.num_computed_tokens}"
)
# request should no longer be in the failed/finished receiving sets
assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids
assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids
# request should be in the running queue
assert request in recompute_scheduler.running

View File

@ -64,6 +64,11 @@ class KVTransferConfig:
enable_permute_local_kv: bool = False
"""Experiment feature flag to enable HND to NHD KV Transfer"""
kv_load_failure_policy: Literal["recompute", "fail"] = "recompute"
"""Policy for handling KV cache load failures.
'recompute': reschedule the request to recompute failed blocks (default)
'fail': immediately fail the request with an error finish reason"""
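
For reference, a minimal sketch of opting into the "fail" policy. The connector name and role below are illustrative examples and not part of this change; only kv_load_failure_policy is new here.

from vllm.config import KVTransferConfig

kv_cfg = KVTransferConfig(
    kv_connector="NixlConnector",   # example connector; any KV connector applies
    kv_role="kv_both",              # example role
    kv_load_failure_policy="fail",  # abort requests whose KV load fails
)

# Roughly equivalent CLI form when serving (policy passed via the JSON config):
#   vllm serve <model> --kv-transfer-config '{"kv_connector": "NixlConnector", "kv_role": "kv_both", "kv_load_failure_policy": "fail"}'
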
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@ -51,7 +51,11 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
@ -380,6 +384,8 @@ class OpenAIServingChat(OpenAIServing):
tokenizer,
request_metadata,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@ -1120,6 +1126,10 @@ class OpenAIServingChat(OpenAIServing):
# if the model is finished generating
else:
# check for error finish reason and abort streaming
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request_id)
# check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously
# matched by partial json parsing
@ -1287,6 +1297,8 @@ class OpenAIServingChat(OpenAIServing):
delta=False,
)
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in chat completion stream generator.")
@ -1327,6 +1339,9 @@ class OpenAIServingChat(OpenAIServing):
role = self.get_chat_request_role(request)
for output in final_res.outputs:
# check for error finish reason and raise GenerationError
# finish_reason='error' indicates a retryable request-level internal error
self._raise_if_error(output.finish_reason, request_id)
token_ids = output.token_ids
out_logprobs = output.logprobs
tool_call_info = None

View File

@ -24,7 +24,11 @@ from vllm.entrypoints.openai.protocol import (
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
@ -300,6 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@ -437,6 +443,8 @@ class OpenAIServingCompletion(OpenAIServing):
finish_reason = output.finish_reason
stop_reason = output.stop_reason
self._raise_if_error(finish_reason, request_id)
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
@ -498,8 +506,11 @@ class OpenAIServingCompletion(OpenAIServing):
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = final_usage_info
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in completion stream generator.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
@ -530,6 +541,8 @@ class OpenAIServingCompletion(OpenAIServing):
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in final_res.outputs:
self._raise_if_error(output.finish_reason, request_id)
assert request.max_tokens is not None
if request.echo:
if request.return_token_ids:

View File

@ -133,6 +133,15 @@ from vllm.utils.async_utils import (
from vllm.utils.collection_utils import is_list_of
from vllm.v1.engine import EngineCoreRequest
class GenerationError(Exception):
"""raised when finish_reason indicates internal server error (500)"""
def __init__(self, message: str = "Internal server error"):
super().__init__(message)
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
logger = init_logger(__name__)
CompletionLikeRequest: TypeAlias = (
@ -456,6 +465,29 @@ class OpenAIServing:
# Iterate through all beam inference results
for i, result in enumerate(output):
current_beam = all_beams[i]
# check for error finish reason and abort beam search
if result.outputs[0].finish_reason == "error":
# yield error output and terminate beam search
yield RequestOutput(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
)
return
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
all_beams_token_id.extend(list(logprobs.keys()))
@ -780,6 +812,35 @@ class OpenAIServing:
)
return json_str
def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
"""Raise GenerationError if finish_reason indicates an error."""
if finish_reason == "error":
logger.error(
"Request %s failed with an internal error during generation",
request_id,
)
raise GenerationError("Internal server error")
def _convert_generation_error_to_response(
self, e: GenerationError
) -> ErrorResponse:
"""Convert GenerationError to ErrorResponse."""
return self.create_error_response(
str(e),
err_type="InternalServerError",
status_code=e.status_code,
)
def _convert_generation_error_to_streaming_response(
self, e: GenerationError
) -> str:
"""Convert GenerationError to streaming error response."""
return self.create_streaming_error_response(
str(e),
err_type="InternalServerError",
status_code=e.status_code,
)
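
Together these helpers give every OpenAI-compatible handler the same contract: raise on finish_reason == "error", catch GenerationError at the top level, and emit either a 500 ErrorResponse or a streaming SSE error chunk. A self-contained sketch of that contract with stand-in types (an illustration, not the actual OpenAIServing code):

import json
from http import HTTPStatus


class GenerationError(Exception):
    """Raised when finish_reason indicates an internal server error (500)."""

    def __init__(self, message: str = "Internal server error"):
        super().__init__(message)
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR


def raise_if_error(finish_reason: str | None) -> None:
    # Mirrors _raise_if_error: only the sentinel "error" reason raises.
    if finish_reason == "error":
        raise GenerationError()


def error_body(e: GenerationError) -> str:
    # Mirrors the ErrorResponse shape asserted in the tests (type/message/code).
    return json.dumps({
        "error": {
            "type": "InternalServerError",
            "message": str(e),
            "code": int(e.status_code),
        }
    })


try:
    for finish_reason in (None, "stop", "error"):
        raise_if_error(finish_reason)
except GenerationError as exc:
    # Non-streaming handlers return this body; streaming handlers yield it as
    # "data: <body>\n\n" followed by "data: [DONE]\n\n".
    print(error_body(exc))
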
async def _check_model(
self,
request: AnyRequest,

View File

@ -50,6 +50,7 @@ from openai.types.responses.response_reasoning_item import (
)
from openai.types.responses.tool import Mcp, Tool
from openai_harmony import Message as OpenAIHarmonyMessage
from pydantic import TypeAdapter
from vllm import envs
from vllm.engine.protocol import EngineClient
@ -94,7 +95,10 @@ from vllm.entrypoints.openai.protocol import (
ResponseUsage,
StreamingResponsesResponse,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.responses_utils import (
construct_input_messages,
@ -541,6 +545,8 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer,
request_metadata,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except Exception as e:
return self.create_error_response(str(e))
@ -648,6 +654,8 @@ class OpenAIServingResponses(OpenAIServing):
status = "incomplete"
elif context.finish_reason == "abort":
status = "cancelled"
else:
self._raise_if_error(context.finish_reason, request.request_id)
else:
status = "incomplete"
elif isinstance(context, ParsableContext):
@ -673,6 +681,9 @@ class OpenAIServingResponses(OpenAIServing):
assert len(final_res.outputs) == 1
final_output = final_res.outputs[0]
# finish_reason='error' indicates retryable internal error
self._raise_if_error(final_output.finish_reason, request.request_id)
output = self._make_response_output_items(request, final_output, tokenizer)
if request.enable_response_messages:
@ -1066,6 +1077,8 @@ class OpenAIServingResponses(OpenAIServing):
async for event in generator:
event_deque.append(event)
new_event_signal.set() # Signal new event available
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))
@ -1089,6 +1102,8 @@ class OpenAIServingResponses(OpenAIServing):
):
try:
response = await self.responses_full_generator(request, *args, **kwargs)
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))
@ -1227,6 +1242,8 @@ class OpenAIServingResponses(OpenAIServing):
continue
if ctx.last_output.outputs:
output = ctx.last_output.outputs[0]
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request.request_id)
if reasoning_parser:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text,
@ -1522,6 +1539,9 @@ class OpenAIServingResponses(OpenAIServing):
async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext)
# finish_reason='error' indicates a retryable error
self._raise_if_error(ctx.finish_reason, request.request_id)
if ctx.is_expecting_start():
current_output_index += 1
sent_output_item_added = False
@ -2016,18 +2036,25 @@ class OpenAIServingResponses(OpenAIServing):
)
)
async for event_data in processer(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
_increment_sequence_number_and_return,
):
yield event_data
try:
async for event_data in processer(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
_increment_sequence_number_and_return,
):
yield event_data
except GenerationError as e:
error_json = self._convert_generation_error_to_streaming_response(e)
yield _increment_sequence_number_and_return(
TypeAdapter(StreamingResponsesResponse).validate_json(error_json)
)
return
async def empty_async_generator():
# A hack to trick Python to think this is a generator but

View File

@ -397,6 +397,25 @@ class BlockPool:
[block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
)
def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs.
only evicts blocks that are currently cached (have a hash). blocks
with ref_cnt > 0 are not freed from the block pool, only evicted
from the prefix cache hash table.
Args:
block_ids: Set of block IDs to evict from cache.
"""
for block_id in block_ids:
assert block_id < len(self.blocks), (
f"Invalid block_id {block_id} >= {len(self.blocks)}. "
f"This indicates a bug in the KV connector - workers should "
f"only report block IDs that were allocated by the scheduler."
)
block = self.blocks[block_id]
self._maybe_evict_cached_block(block)
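
The semantics are intentionally narrow: eviction only removes the prefix-cache hash entry so future lookups cannot hit the block; it never frees memory that a running request still references. A toy, self-contained sketch of that behavior (simplified stand-ins, not vLLM's actual BlockPool):

from dataclasses import dataclass


@dataclass
class Block:
    block_id: int
    block_hash: int | None = None
    ref_cnt: int = 0


def evict(blocks: list[Block], cached: dict[int, Block], block_ids: set[int]) -> None:
    for block_id in block_ids:
        block = blocks[block_id]
        if block.block_hash is None:
            continue  # never cached, nothing to evict
        cached.pop(block.block_hash, None)  # future lookups can no longer match it
        block.block_hash = None  # hash reset, as the eviction tests assert
        # Note: the block's memory is untouched; ref_cnt still governs freeing.


blocks = [Block(0, block_hash=111, ref_cnt=1), Block(1, block_hash=222)]
cached = {111: blocks[0], 222: blocks[1]}
evict(blocks, cached, {0})
assert blocks[0].block_hash is None and 111 not in cached
assert blocks[1].block_hash == 222  # untouched block stays cached
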
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,

View File

@ -333,6 +333,14 @@ class KVCacheManager:
"""
self.coordinator.free(request.request_id)
def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs.
Args:
block_ids: Set of block IDs to evict from cache.
"""
self.block_pool.evict_blocks(block_ids)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,

View File

@ -106,6 +106,7 @@ class Scheduler(SchedulerInterface):
# KV Connector pushes/pull of remote KVs for P/D and offloading.
self.connector = None
self.connector_prefix_cache_stats: PrefixCacheStats | None = None
self.recompute_kv_load_failures = True
if self.vllm_config.kv_transfer_config is not None:
assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported with KV connectors"
@ -117,6 +118,10 @@ class Scheduler(SchedulerInterface):
)
if self.log_stats:
self.connector_prefix_cache_stats = PrefixCacheStats()
kv_load_failure_policy = (
self.vllm_config.kv_transfer_config.kv_load_failure_policy
)
self.recompute_kv_load_failures = kv_load_failure_policy == "recompute"
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
@ -1066,7 +1071,7 @@ class Scheduler(SchedulerInterface):
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
assert num_tokens_scheduled > 0
if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
# Skip requests that were recovered from KV load failure
# Skip requests that failed or were rescheduled due to a KV load failure
continue
request = self.requests.get(req_id)
if request is None:
@ -1177,6 +1182,21 @@ class Scheduler(SchedulerInterface):
# This is a rare case and unlikely to impact performance.
self.waiting.remove_requests(stopped_preempted_reqs)
if failed_kv_load_req_ids and not self.recompute_kv_load_failures:
requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids]
self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR)
for request in requests:
outputs[request.client_index].append(
EngineCoreOutput(
request_id=request.request_id,
new_token_ids=[],
finish_reason=request.get_finished_reason(),
events=request.take_events(),
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
)
)
# KV Connector: update state for finished KV Transfers.
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)
@ -1610,8 +1630,11 @@ class Scheduler(SchedulerInterface):
self._free_blocks(self.requests[req_id])
def _update_requests_with_invalid_blocks(
self, requests: Iterable[Request], invalid_block_ids: set[int]
) -> tuple[set[str], int]:
self,
requests: Iterable[Request],
invalid_block_ids: set[int],
evict_blocks: bool = True,
) -> tuple[set[str], int, set[int]]:
"""
Identify and update requests affected by invalid KV cache blocks.
@ -1623,16 +1646,21 @@ class Scheduler(SchedulerInterface):
Args:
requests: The set of requests to scan for invalid blocks.
invalid_block_ids: IDs of invalid blocks.
evict_blocks: Whether to collect blocks for eviction (False for
async requests which aren't cached yet).
Returns:
tuple:
- affected_req_ids (set[str]): IDs of requests impacted by
invalid blocks.
- total_affected_tokens (int): Total number of tokens that must
be recomputed across all affected requests (for observability).
be recomputed across all affected requests.
- blocks_to_evict (set[int]): Block IDs to evict from cache,
including invalid blocks and downstream dependent blocks.
"""
affected_req_ids: set[str] = set()
total_affected_tokens = 0
blocks_to_evict: set[int] = set()
# If a block is invalid and shared by multiple requests in the batch,
# these requests must be rescheduled, but only the first will recompute
# it. This set tracks blocks already marked for recomputation.
@ -1690,6 +1718,9 @@ class Scheduler(SchedulerInterface):
)
total_affected_tokens += num_affected_tokens
request.num_external_computed_tokens -= num_affected_tokens
# collect invalid block and all downstream dependent blocks
if evict_blocks:
blocks_to_evict.update(req_block_ids[idx:])
if is_affected:
if not marked_invalid_block:
@ -1705,47 +1736,70 @@ class Scheduler(SchedulerInterface):
affected_req_ids.add(request.request_id)
return affected_req_ids, total_affected_tokens
return affected_req_ids, total_affected_tokens, blocks_to_evict
def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
total_requests_to_reschedule = 0
total_tokens_to_reschedule = 0
"""
Handle requests affected by invalid KV cache blocks.
# --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) ---
Returns:
Set of affected request IDs to skip in update_from_output main loop.
"""
should_fail = not self.recompute_kv_load_failures
# handle async KV loads (not cached yet, evict_blocks=False)
async_load_reqs = (
req
for req in self.waiting
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
)
async_affected_req_ids, num_tokens_to_reschedule = (
async_failed_req_ids, num_failed_tokens, _ = (
self._update_requests_with_invalid_blocks(
async_load_reqs, invalid_block_ids
async_load_reqs, invalid_block_ids, evict_blocks=False
)
)
total_requests_to_reschedule += len(async_affected_req_ids)
total_tokens_to_reschedule += num_tokens_to_reschedule
total_failed_requests = len(async_failed_req_ids)
total_failed_tokens = num_failed_tokens
# Mark requests with async KV load failures; they will be rescheduled
# once loading completes.
self.failed_recving_kv_req_ids |= async_affected_req_ids
# --- Handle sync KV loads (running requests) ---
sync_affected_req_ids, num_tokens_to_reschedule = (
self._update_requests_with_invalid_blocks(self.running, invalid_block_ids)
# handle sync loads (may be cached, collect blocks for eviction)
sync_failed_req_ids, num_failed_tokens, sync_blocks_to_evict = (
self._update_requests_with_invalid_blocks(
self.running, invalid_block_ids, evict_blocks=True
)
)
total_requests_to_reschedule += len(sync_affected_req_ids)
total_tokens_to_reschedule += num_tokens_to_reschedule
total_failed_requests += len(sync_failed_req_ids)
total_failed_tokens += num_failed_tokens
if total_requests_to_reschedule:
logger.warning(
"Recovered from KV load failure: "
"%d request(s) rescheduled (%d tokens affected).",
total_requests_to_reschedule,
total_tokens_to_reschedule,
if not total_failed_requests:
return set()
# evict invalid blocks and downstream dependent blocks from cache
# only when not using recompute policy (where blocks will be recomputed
# and reused by other requests sharing them)
if sync_blocks_to_evict and not self.recompute_kv_load_failures:
self.kv_cache_manager.evict_blocks(sync_blocks_to_evict)
if should_fail:
all_failed_req_ids = async_failed_req_ids | sync_failed_req_ids
logger.error(
"Failing %d request(s) due to KV load failure "
"(failure_policy=fail, %d tokens affected). Request IDs: %s",
total_failed_requests,
total_failed_tokens,
all_failed_req_ids,
)
return all_failed_req_ids
# Return the IDs of affected running requests to skip in
# update_from_output.
return sync_affected_req_ids
logger.warning(
"Recovered from KV load failure: "
"%d request(s) rescheduled (%d tokens affected).",
total_failed_requests,
total_failed_tokens,
)
# Mark async requests with KV load failures for retry once loading completes
self.failed_recving_kv_req_ids |= async_failed_req_ids
# Return sync affected IDs to skip in update_from_output
return sync_failed_req_ids
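
The rewritten method reduces to a small policy branch: with "fail", evict the affected blocks and report every affected request so it can be finished with ERROR; with "recompute", keep the blocks, remember async failures for retry once the transfer completes, and skip the sync requests so they recompute from the truncated point. A condensed, self-contained sketch of that branch (toy arguments, not the Scheduler API):

def handle_invalid_blocks_sketch(policy, async_failed, sync_failed,
                                 sync_blocks_to_evict, evict_blocks,
                                 failed_recving_kv_req_ids):
    """Return request IDs for update_from_output to treat as failed/skipped."""
    if not (async_failed or sync_failed):
        return set()
    if policy == "fail":
        # Evict invalid (and downstream) blocks so no future request reuses
        # them, then report every affected request for FINISHED_ERROR.
        if sync_blocks_to_evict:
            evict_blocks(sync_blocks_to_evict)
        return async_failed | sync_failed
    # policy == "recompute": keep the blocks for reuse, remember async
    # failures so they are rescheduled after loading completes, and skip the
    # sync requests in the output loop so they recompute the invalid portion.
    failed_recving_kv_req_ids |= async_failed
    return sync_failed


failed_recving: set[str] = set()
evicted: list[set[int]] = []
out = handle_invalid_blocks_sketch("fail", {"req-a"}, {"req-b"}, {7, 8},
                                   evicted.append, failed_recving)
assert out == {"req-a", "req-b"} and evicted == [{7, 8}]
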

View File

@ -19,24 +19,27 @@ from vllm.v1.serial_utils import UtilityResult
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
class FinishReason(enum.IntEnum):
"""
Reason a request finished - stop, length, or abort.
Reason a request finished - stop, length, abort, or error.
Int rather than Str for more compact serialization.
stop - a stop string was emitted
length - max_tokens was consumed, or max_model_len was reached
abort - aborted for another reason
abort - aborted by client
error - retryable request-level internal error (e.g., KV load failure).
Invariant: always converted to 500 Internal Server Error.
"""
STOP = 0
LENGTH = 1
ABORT = 2
ERROR = 3
def __str__(self):
return FINISH_REASON_STRINGS[self.value]

View File

@ -255,6 +255,7 @@ class RequestStatus(enum.IntEnum):
FINISHED_LENGTH_CAPPED = enum.auto()
FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
FINISHED_ERROR = enum.auto()
def __str__(self):
return self.name
@ -277,4 +278,5 @@ _FINISHED_REASON_MAP = {
RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
}
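
End to end, the new reason threads through the whole stack: the scheduler finishes the request as FINISHED_ERROR, which maps to FinishReason.ERROR, serializes as the string "error", and is converted into a 500 by the frontend's GenerationError path. A quick sketch using the same import path as the scheduler tests in this PR:

from vllm.v1.request import FinishReason, RequestStatus

# FinishReason.ERROR serializes to the "error" string the frontend checks for.
assert str(FinishReason.ERROR) == "error"
# RequestStatus.FINISHED_ERROR is what the scheduler sets on failed requests;
# Request.get_finished_reason() maps it to FinishReason.ERROR via the table above.
print(RequestStatus.FINISHED_ERROR)  # -> FINISHED_ERROR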