[P/D] KV Load Failure Recovery/Abort Configuration (#26813)
Signed-off-by: Will Eaton <weaton@redhat.com>
Signed-off-by: Will Eaton <me@wseaton.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: chaunceyjiang <chaunceyjiang@gmail.com>

Parent: e8e8cd73e5
Commit: a9e4106f28

tests/entrypoints/openai/test_chat_error.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]


@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    serving_chat = OpenAIServingChat(
        engine,
        models,
        response_role="assistant",
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        return dict(engine_prompt), {}

    async def _fake_preprocess_chat(*args, **kwargs):
        # return conversation, request_prompts, engine_prompts
        return (
            [{"role": "user", "content": "Test"}],
            [[1, 2, 3]],
            [{"prompt_token_ids": [1, 2, 3]}],
        )

    serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
    return serving_chat


@pytest.mark.asyncio
async def test_chat_error_non_stream():
    """test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    completion_output = CompletionOutput(
        index=0,
        text="",
        token_ids=[],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=False,
    )

    response = await serving_chat.create_chat_completion(request)

    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InternalServerError"
    assert response.error.message == "Internal server error"
    assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR


@pytest.mark.asyncio
async def test_chat_error_stream():
    """test finish_reason='error' returns 500 InternalServerError (streaming)"""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    completion_output_1 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason=None,
    )

    request_output_1 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_1],
        finished=False,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    completion_output_2 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output_2 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_2],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output_1
        yield request_output_2

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=True,
    )

    response = await serving_chat.create_chat_completion(request)

    chunks = []
    async for chunk in response:
        chunks.append(chunk)

    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"

tests/entrypoints/openai/test_completion_error.py (new file, 216 lines)
@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]


@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
    models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    serving_completion = OpenAIServingCompletion(
        engine,
        models,
        request_logger=None,
    )

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        return dict(engine_prompt), {}

    serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    return serving_completion


@pytest.mark.asyncio
async def test_completion_error_non_stream():
    """test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_completion = _build_serving_completion(mock_engine)

    completion_output = CompletionOutput(
        index=0,
        text="",
        token_ids=[],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
        max_tokens=10,
        stream=False,
    )

    response = await serving_completion.create_completion(request)

    assert isinstance(response, ErrorResponse)
    assert response.error.type == "InternalServerError"
    assert response.error.message == "Internal server error"
    assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR


@pytest.mark.asyncio
async def test_completion_error_stream():
    """test finish_reason='error' returns 500 InternalServerError (streaming)"""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_completion = _build_serving_completion(mock_engine)

    completion_output_1 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason=None,
    )

    request_output_1 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_1],
        finished=False,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    completion_output_2 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output_2 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_2],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output_1
        yield request_output_2

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
        max_tokens=10,
        stream=True,
    )

    response = await serving_completion.create_completion(request)

    chunks = []
    async for chunk in response:
        chunks.append(chunk)

    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"

tests/entrypoints/openai/test_responses_error.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from http import HTTPStatus
from unittest.mock import MagicMock

import pytest

from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.serving_engine import GenerationError, OpenAIServing


@pytest.mark.asyncio
async def test_raise_if_error_raises_generation_error():
    """test _raise_if_error raises GenerationError"""
    # create a minimal OpenAIServing instance
    mock_engine = MagicMock()
    mock_engine.model_config = MagicMock()
    mock_engine.model_config.max_model_len = 100
    mock_models = MagicMock()

    serving = OpenAIServing(
        engine_client=mock_engine,
        models=mock_models,
        request_logger=None,
    )

    # test that error finish_reason raises GenerationError
    with pytest.raises(GenerationError) as exc_info:
        serving._raise_if_error("error", "test-request-id")

    assert str(exc_info.value) == "Internal server error"
    assert exc_info.value.status_code == HTTPStatus.INTERNAL_SERVER_ERROR

    # test that other finish_reasons don't raise
    serving._raise_if_error("stop", "test-request-id")  # should not raise
    serving._raise_if_error("length", "test-request-id")  # should not raise
    serving._raise_if_error(None, "test-request-id")  # should not raise


@pytest.mark.asyncio
async def test_convert_generation_error_to_response():
    """test _convert_generation_error_to_response creates proper ErrorResponse"""
    mock_engine = MagicMock()
    mock_engine.model_config = MagicMock()
    mock_engine.model_config.max_model_len = 100
    mock_models = MagicMock()

    serving = OpenAIServing(
        engine_client=mock_engine,
        models=mock_models,
        request_logger=None,
    )

    # create a GenerationError
    gen_error = GenerationError("Internal server error")

    # convert to ErrorResponse
    error_response = serving._convert_generation_error_to_response(gen_error)

    assert isinstance(error_response, ErrorResponse)
    assert error_response.error.type == "InternalServerError"
    assert error_response.error.message == "Internal server error"
    assert error_response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR


@pytest.mark.asyncio
async def test_convert_generation_error_to_streaming_response():
    """test _convert_generation_error_to_streaming_response output"""
    mock_engine = MagicMock()
    mock_engine.model_config = MagicMock()
    mock_engine.model_config.max_model_len = 100
    mock_models = MagicMock()

    serving = OpenAIServing(
        engine_client=mock_engine,
        models=mock_models,
        request_logger=None,
    )

    # create a GenerationError
    gen_error = GenerationError("Internal server error")

    # convert to streaming error response
    error_json = serving._convert_generation_error_to_streaming_response(gen_error)

    assert isinstance(error_json, str)
    assert "Internal server error" in error_json
    assert "InternalServerError" in error_json

tests/v1/kv_connector/unit/test_cache_pollution_prevention.py (new file, 163 lines)
@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
test that invalid blocks are evicted from prefix cache to prevent pollution.

verifies that when sync-loading fails, invalid blocks are removed from the
prefix cache hash table so future requests cannot match and reuse corrupted data.
"""

from collections.abc import Callable
from unittest.mock import Mock

import pytest

from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus

from .utils import (
    create_model_runner_output,
    create_request,
    create_scheduler,
    create_vllm_config,
)

pytestmark = pytest.mark.cpu_test


def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
    def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
        value = req_num_new_matched_tokens.get(request.request_id, 0)
        return value, async_load

    return get_num_new_matched_tokens


@pytest.fixture
def fail_scheduler():
    """scheduler with kv_load_failure_policy='fail'"""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)


def test_invalid_blocks_evicted_prevents_cache_pollution(
    fail_scheduler: Scheduler,
):
    """
    verify invalid blocks are evicted to prevent future cache hits.

    scenario:
    1. request 1 loads externally-computed blocks (sync mode)
    2. some blocks fail to load and are marked invalid
    3. with fail policy, invalid blocks should be evicted from prefix cache
    4. request is marked as FINISHED_ERROR
    """
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    # request 1: will have invalid blocks
    request1 = create_request(num_tokens=num_prompt_tokens, request_id=1)
    fail_scheduler.add_request(request=request1)

    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    # request should be running with sync KV load
    assert len(fail_scheduler.running) == 1
    assert request1.status == RequestStatus.RUNNING

    # get allocated block IDs
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_id = req_block_ids[invalid_block_idx]
    invalid_block_ids = {invalid_block_id}

    # get the block object to verify eviction later
    block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]

    # cache the blocks to simulate they've been computed and cached
    # (in real scenario blocks would be cached after compute)
    fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens)

    # verify block has a hash (is cached) before reporting invalid blocks
    assert block.block_hash is not None, (
        f"block {invalid_block_id} should be cached (have a hash) before "
        f"eviction test, but hash is None"
    )

    # report invalid blocks
    model_runner_output = create_model_runner_output(
        [request1],
        invalid_block_ids=invalid_block_ids,
        use_eos=False,
    )

    fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    # verify request finished with error (fail policy)
    assert request1.status == RequestStatus.FINISHED_ERROR

    # critical assertion: invalid block and all subsequent blocks should be evicted
    # all blocks from invalid_block_idx onwards become invalid since they were
    # computed based on the failed block
    for idx in range(invalid_block_idx, len(req_block_ids)):
        block_id = req_block_ids[idx]
        block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
        assert block_obj.block_hash is None, (
            f"block {block_id} at index {idx} should have been evicted "
            f"(hash reset to None), but hash is {block_obj.block_hash}. "
            f"All blocks from index {invalid_block_idx} onwards should be evicted "
            f"since they depend on the invalid block at index {invalid_block_idx}."
        )

    # verify cache contains exactly the valid blocks (before first affected block)
    # and none of the invalid blocks (from first affected block onwards)

    # valid blocks: all blocks before invalid_block_idx should be cached
    for idx in range(invalid_block_idx):
        block_id = req_block_ids[idx]
        block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
        assert block_obj.block_hash is not None, (
            f"valid block {block_id} at index {idx} should still be cached "
            f"(have a hash), but hash is None. Only blocks from index "
            f"{invalid_block_idx} onwards should be evicted."
        )

    # invalid blocks: verify they're not in the cached_block_hash_to_block map
    cached_blocks = (
        fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
    )
    cached_block_ids = {
        b.block_id
        for blocks_val in cached_blocks._cache.values()
        for b in (
            [blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values()
        )
    }

    for idx in range(invalid_block_idx, len(req_block_ids)):
        block_id = req_block_ids[idx]
        assert block_id not in cached_block_ids, (
            f"invalid block {block_id} at index {idx} should not be in cache hash table"
        )

tests/v1/kv_connector/unit/test_error_propagation.py (new file, 147 lines)
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from unittest.mock import Mock

import pytest

from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus

from .utils import (
    create_model_runner_output,
    create_request,
    create_scheduler,
    create_vllm_config,
)

pytestmark = pytest.mark.cpu_test


def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
    def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
        value = req_num_new_matched_tokens.get(request.request_id, 0)
        return value, async_load

    return get_num_new_matched_tokens


@pytest.fixture
def fail_scheduler():
    """scheduler with kv_load_failure_policy='fail'"""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)


def test_error_propagation_sync_load(fail_scheduler: Scheduler):
    """test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)"""
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    assert len(fail_scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1

    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req_block_ids[invalid_block_idx]}
    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    assert len(fail_scheduler.running) == 0


def test_error_propagation_async_load(fail_scheduler: Scheduler):
    """test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)"""
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    assert len(fail_scheduler.waiting) == 1
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == 0

    (req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
    invalid_block_ids = {req_block_ids[invalid_block_idx]}
    model_runner_output = create_model_runner_output(
        reqs=[],
        finished_recving=set(),
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    assert len(fail_scheduler.waiting) == 0

tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py (new file, 454 lines)
@@ -0,0 +1,454 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Tests for correctness in invalid block handling.

These tests verify correct behavior in three scenarios:
1. Sync recompute case: Blocks should not be freed for running requests
   that need to recompute invalid blocks
2. Sync fail case: Invalid blocks must be evicted from cache when request fails
3. Async recompute case: Invalid blocks should not be cached after transfer
"""

from collections.abc import Callable
from unittest.mock import Mock

import pytest

from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus

from .utils import (
    create_model_runner_output,
    create_request,
    create_scheduler,
    create_vllm_config,
)

pytestmark = pytest.mark.cpu_test


def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
    def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
        value = req_num_new_matched_tokens.get(request.request_id, 0)
        return value, async_load

    return get_num_new_matched_tokens


@pytest.fixture
def fail_scheduler():
    """scheduler with kv_load_failure_policy='fail'"""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)


@pytest.fixture
def recompute_scheduler():
    """scheduler with kv_load_failure_policy='recompute'"""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "recompute"
    return create_scheduler(vllm_config)


def test_sync_recompute_blocks_not_freed_for_running_requests(
    recompute_scheduler: Scheduler,
):
    """
    Test sync recompute case - blocks must not be freed for running requests.

    When a running request has invalid blocks and retry_policy is 'recompute':
    1. Request should remain in RUNNING state
    2. num_computed_tokens should be truncated to invalid block boundary
    3. Blocks should NOT be freed (request still needs them for recomputation)
    4. Request should remain in scheduler.requests and scheduler.running
    """
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * recompute_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    recompute_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load
    recompute_scheduler.connector = Mock()
    recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    recompute_scheduler.connector.request_finished.return_value = (False, None)
    recompute_scheduler.connector.take_events.return_value = ()

    scheduler_output = recompute_scheduler.schedule()

    # request should be running with sync KV load
    assert len(recompute_scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert request.status == RequestStatus.RUNNING

    # get the allocated block IDs before invalid blocks are reported
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req_block_ids[invalid_block_idx]}

    # store original num_computed_tokens for comparison
    original_num_computed_tokens = request.num_computed_tokens

    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=False,  # not finished - should continue running
    )

    outputs = recompute_scheduler.update_from_output(
        scheduler_output, model_runner_output
    )

    # critical assertions for recompute case:

    # 1. request should still be RUNNING (not finished, not aborted)
    assert request.status == RequestStatus.RUNNING, (
        f"Request should remain RUNNING for recompute, got {request.status}"
    )

    # 2. num_computed_tokens should be truncated to first invalid block
    expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size
    assert request.num_computed_tokens == expected_truncated_tokens, (
        f"num_computed_tokens should be truncated to {expected_truncated_tokens}, "
        f"got {request.num_computed_tokens}"
    )
    assert request.num_computed_tokens < original_num_computed_tokens, (
        "num_computed_tokens should be reduced after invalid block detection"
    )

    # 3. no output should be generated (request is still running)
    # the request should be skipped in the output loop
    assert len(outputs) == 0 or request.request_id not in [
        out.request_id for outs in outputs.values() for out in outs.outputs
    ], "No output should be generated for recompute requests"

    # 4. request should still be in running queue
    assert request in recompute_scheduler.running, (
        "Request should remain in running queue for recomputation"
    )

    # 5. request should still be in scheduler.requests (not deleted)
    assert request.request_id in recompute_scheduler.requests, (
        "Request should not be deleted from scheduler.requests"
    )

    # 6. blocks should NOT be freed - verify blocks are still allocated
    try:
        allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids(
            request.request_id
        )
        assert allocated_blocks is not None
        assert len(allocated_blocks[0]) > 0, (
            "Blocks should still be allocated for recomputation"
        )
    except KeyError:
        pytest.fail(
            "Blocks were freed incorrectly! Running requests need their blocks "
            "to recompute invalid portions."
        )

    # 7. verify request can be rescheduled in next step
    scheduler_output_2 = recompute_scheduler.schedule()

    # request should appear in the new schedule to recompute invalid blocks
    scheduled_req_ids = [
        req.request_id for req in scheduler_output_2.scheduled_new_reqs
    ]
    if scheduler_output_2.num_scheduled_tokens:
        scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys())

    assert (
        request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0
    ), "Request should be reschedulable for recomputation"


def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
    """
    Test sync fail case - invalid blocks must be evicted from cache.

    When a request fails with policy='fail' and has invalid blocks from sync loading:
    1. Request should be finished with FINISHED_ERROR
    2. Invalid blocks should be evicted from the KV cache
    3. Valid blocks (if shared) should remain in cache
    4. Future requests should not reuse the invalid blocks

    This test verifies that invalid blocks are properly evicted to prevent
    cache corruption and reuse of invalid data.
    """
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    # request should be running with sync KV load
    assert len(fail_scheduler.running) == 1
    assert request.status == RequestStatus.RUNNING

    # get allocated block IDs
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_id = req_block_ids[invalid_block_idx]
    invalid_block_ids = {invalid_block_id}

    # verify the block is in the block pool before we report it as invalid
    block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
    assert block is not None

    # report invalid blocks - request should fail
    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    # verify request is finished with error
    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    # verify output is generated
    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    # verify the request was removed from scheduler
    assert request.request_id not in fail_scheduler.requests
    assert len(fail_scheduler.running) == 0

    # critical: verify invalid block was actually freed from cache
    # this is the key assertion - the invalid block should no longer be
    # tracked by the KV cache manager for this request
    # if it's still there, a future request could reuse the invalid data
    try:
        block_ids = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
        # if we get here, check if blocks were actually freed
        if block_ids is not None and len(block_ids[0]) > 0:
            pytest.fail(
                f"Invalid blocks still tracked for finished request! "
                f"Request {request.request_id} should have been freed but "
                f"still has {len(block_ids[0])} blocks allocated."
            )
        # blocks list exists but is empty - this is fine, they were freed
    except KeyError:
        # expected - request completely removed from tracking
        pass

    # critical: verify invalid block was evicted from prefix cache
    # the block should no longer have a hash (hash is reset on eviction)
    assert block.block_hash is None, (
        f"Invalid block {invalid_block_id} should have been evicted from cache "
        f"(hash should be None), but hash is still {block.block_hash}"
    )


def test_async_recompute_blocks_not_cached_when_invalid(
    recompute_scheduler: Scheduler,
):
    """
    Test async recompute case - invalid blocks not cached after transfer.

    When async KV loading has invalid blocks and retry_policy is 'recompute':
    1. Blocks are allocated but not cached yet
    2. When async transfer completes, only valid blocks should be cached
    3. Invalid blocks should never enter the prefix cache

    This test verifies correctness, the failed_recving_kv_req_ids protection
    ensures only valid blocks are cached when the transfer completes, and we
    only evict blocks from cache that are already hashed in the block table.
    """
    from unittest.mock import patch

    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * recompute_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    recompute_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector indicating async load
    recompute_scheduler.connector = Mock()
    recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
    )
    recompute_scheduler.connector.request_finished.return_value = (False, None)
    recompute_scheduler.connector.take_events.return_value = ()

    scheduler_output = recompute_scheduler.schedule()

    # request should be waiting for remote KVs
    assert len(recompute_scheduler.waiting) == 1
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == 0

    # get the allocated block IDs
    (req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
        request.request_id
    )
    invalid_block_id = req_block_ids[invalid_block_idx]
    invalid_block_ids = {invalid_block_id}

    # get the block object to verify it's not cached yet and stays uncached
    block = recompute_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]

    # verify block has no hash before invalid blocks are reported
    assert block.block_hash is None, (
        "Async loading blocks should not be cached yet (no hash)"
    )

    # report invalid blocks (transfer not finished yet)
    model_runner_output = create_model_runner_output(
        reqs=[],
        finished_recving=None,  # transfer NOT finished
        invalid_block_ids=invalid_block_ids,
        use_eos=False,
    )

    # critical: spy on evict_blocks to verify it's NOT called for async blocks
    original_evict_blocks = recompute_scheduler.kv_cache_manager.evict_blocks
    evict_blocks_calls = []

    def evict_blocks_spy(block_ids):
        evict_blocks_calls.append(set(block_ids))
        return original_evict_blocks(block_ids)

    with patch.object(
        recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
    ):
        recompute_scheduler.update_from_output(scheduler_output, model_runner_output)

    # verify evict_blocks was NOT called (async blocks excluded from eviction)
    assert len(evict_blocks_calls) == 0, (
        f"evict_blocks should not be called for async-only invalid blocks, "
        f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}"
    )

    # request should still be waiting (not finished with error due to recompute policy)
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids

    # verify num_computed_tokens was truncated to before invalid block
    expected_valid_tokens = invalid_block_idx * recompute_scheduler.block_size
    assert request.num_computed_tokens == expected_valid_tokens

    # verify invalid block still has no hash (was not evicted)
    assert block.block_hash is None, (
        f"Async loading blocks shouldn't be cached or evicted. "
        f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
    )

    # now simulate async transfer completing
    model_runner_output_2 = create_model_runner_output(
        reqs=[],
        finished_recving={request.request_id},
        invalid_block_ids=None,
        use_eos=False,
    )

    recompute_scheduler.update_from_output(scheduler_output, model_runner_output_2)

    # verify request is now marked as finished receiving and ready to be processed
    assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids
    assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids

    # critical: verify invalid block still has no hash before recompute
    # the async transfer invalid data was never cached
    assert block.block_hash is None, (
        f"Invalid block {invalid_block_id} should not be cached before recompute "
        f"(hash should be None), but hash is {block.block_hash}"
    )

    # critical end-to-end test: spy on cache_blocks to verify it's called with
    # the truncated num_computed_tokens value
    original_cache_blocks = recompute_scheduler.kv_cache_manager.cache_blocks
    cache_blocks_calls = []

    def cache_blocks_spy(req, num_tokens):
        cache_blocks_calls.append((req.request_id, num_tokens))
        return original_cache_blocks(req, num_tokens)

    with patch.object(
        recompute_scheduler.kv_cache_manager, "cache_blocks", cache_blocks_spy
    ):
        # call schedule() again - this triggers _update_waiting_for_remote_kv()
        # which should call cache_blocks with the truncated value
        recompute_scheduler.schedule()

    # verify cache_blocks was called with the truncated value
    assert len(cache_blocks_calls) == 1, (
        f"cache_blocks should be called exactly once, "
        f"got {len(cache_blocks_calls)} calls"
    )
    cached_req_id, cached_num_tokens = cache_blocks_calls[0]
    assert cached_req_id == request.request_id
    assert cached_num_tokens == expected_valid_tokens, (
        f"cache_blocks should be called with truncated value {expected_valid_tokens}, "
        f"but was called with {cached_num_tokens}"
    )

    # request should now be RUNNING (scheduled immediately after transfer completes)
    # the flow is: WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING in same schedule() call
    assert request.status == RequestStatus.RUNNING

    # num_computed_tokens should be >= expected_valid_tokens because the scheduler
    # will schedule additional new tokens (up to max_num_batched_tokens) for the request
    assert request.num_computed_tokens >= expected_valid_tokens, (
        f"num_computed_tokens should be at least {expected_valid_tokens}, "
        f"got {request.num_computed_tokens}"
    )

    # request should no longer be in the failed/finished receiving sets
    assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids
    assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids

    # request should be in the running queue
    assert request in recompute_scheduler.running

@@ -64,6 +64,11 @@ class KVTransferConfig:
     enable_permute_local_kv: bool = False
     """Experiment feature flag to enable HND to NHD KV Transfer"""

+    kv_load_failure_policy: Literal["recompute", "fail"] = "recompute"
+    """Policy for handling KV cache load failures.
+    'recompute': reschedule the request to recompute failed blocks (default)
+    'fail': immediately fail the request with an error finish reason"""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
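
The new field defaults to the existing recompute behaviour, so only deployments that want hard failures need to set it. Below is a minimal illustrative sketch of opting into the abort behaviour when building the config programmatically; the connector name and role shown are assumptions for the example, not part of this change.

from vllm.config import KVTransferConfig

# Hypothetical example: any P/D connector is configured the same way.
kv_transfer_config = KVTransferConfig(
    kv_connector="NixlConnector",    # assumed connector, for illustration only
    kv_role="kv_both",               # assumed role, for illustration only
    kv_load_failure_policy="fail",   # abort with finish_reason='error' instead of recomputing
)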

vllm/entrypoints/openai/serving_chat.py
@@ -51,7 +51,11 @@ from vllm.entrypoints.openai.protocol import (
    ToolCall,
    UsageInfo,
)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
+from vllm.entrypoints.openai.serving_engine import (
+    GenerationError,
+    OpenAIServing,
+    clamp_prompt_logprobs,
+)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
@@ -380,6 +384,8 @@ class OpenAIServingChat(OpenAIServing):
                tokenizer,
                request_metadata,
            )
+        except GenerationError as e:
+            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
@@ -1120,6 +1126,10 @@ class OpenAIServingChat(OpenAIServing):

                # if the model is finished generating
                else:
+                    # check for error finish reason and abort streaming
+                    # finish_reason='error' indicates a retryable error
+                    self._raise_if_error(output.finish_reason, request_id)
+
                    # check to make sure we haven't "forgotten" to stream
                    # any tokens that were generated but previously
                    # matched by partial json parsing
@@ -1287,6 +1297,8 @@ class OpenAIServingChat(OpenAIServing):
                delta=False,
            )

+        except GenerationError as e:
+            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in chat completion stream generator.")
@@ -1327,6 +1339,9 @@ class OpenAIServingChat(OpenAIServing):

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
+            # check for error finish reason and raise GenerationError
+            # finish_reason='error' indicates a retryable request-level internal error
+            self._raise_if_error(output.finish_reason, request_id)
            token_ids = output.token_ids
            out_logprobs = output.logprobs
            tool_call_info = None

vllm/entrypoints/openai/serving_completion.py
@@ -24,7 +24,11 @@ from vllm.entrypoints.openai.protocol import (
    RequestResponseMetadata,
    UsageInfo,
)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
+from vllm.entrypoints.openai.serving_engine import (
+    GenerationError,
+    OpenAIServing,
+    clamp_prompt_logprobs,
+)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
@@ -300,6 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
+        except GenerationError as e:
+            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
@@ -437,6 +443,8 @@ class OpenAIServingCompletion(OpenAIServing):
                finish_reason = output.finish_reason
                stop_reason = output.stop_reason

+                self._raise_if_error(finish_reason, request_id)
+
                chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
@@ -498,8 +506,11 @@ class OpenAIServingCompletion(OpenAIServing):
            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info
+
+        except GenerationError as e:
+            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"
@@ -530,6 +541,8 @@ class OpenAIServingCompletion(OpenAIServing):
        out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

        for output in final_res.outputs:
+            self._raise_if_error(output.finish_reason, request_id)
+
            assert request.max_tokens is not None
            if request.echo:
                if request.return_token_ids:
@ -133,6 +133,15 @@ from vllm.utils.async_utils import (
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
|
||||
class GenerationError(Exception):
|
||||
"""raised when finish_reason indicates internal server error (500)"""
|
||||
|
||||
def __init__(self, message: str = "Internal server error"):
|
||||
super().__init__(message)
|
||||
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CompletionLikeRequest: TypeAlias = (
|
||||
@ -456,6 +465,29 @@ class OpenAIServing:
|
||||
# Iterate through all beam inference results
|
||||
for i, result in enumerate(output):
|
||||
current_beam = all_beams[i]
|
||||
|
||||
# check for error finish reason and abort beam search
|
||||
if result.outputs[0].finish_reason == "error":
|
||||
# yield error output and terminate beam search
|
||||
yield RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=prompt_text,
|
||||
outputs=[
|
||||
CompletionOutput(
|
||||
index=0,
|
||||
text="",
|
||||
token_ids=[],
|
||||
cumulative_logprob=None,
|
||||
logprobs=None,
|
||||
finish_reason="error",
|
||||
)
|
||||
],
|
||||
finished=True,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=None,
|
||||
)
|
||||
return
|
||||
|
||||
if result.outputs[0].logprobs is not None:
|
||||
logprobs = result.outputs[0].logprobs[0]
|
||||
all_beams_token_id.extend(list(logprobs.keys()))
|
||||
@ -780,6 +812,35 @@ class OpenAIServing:
|
||||
)
|
||||
return json_str
|
||||
|
||||
def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
|
||||
"""Raise GenerationError if finish_reason indicates an error."""
|
||||
if finish_reason == "error":
|
||||
logger.error(
|
||||
"Request %s failed with an internal error during generation",
|
||||
request_id,
|
||||
)
|
||||
raise GenerationError("Internal server error")
|
||||
|
||||
def _convert_generation_error_to_response(
|
||||
self, e: GenerationError
|
||||
) -> ErrorResponse:
|
||||
"""Convert GenerationError to ErrorResponse."""
|
||||
return self.create_error_response(
|
||||
str(e),
|
||||
err_type="InternalServerError",
|
||||
status_code=e.status_code,
|
||||
)
|
||||
|
||||
def _convert_generation_error_to_streaming_response(
|
||||
self, e: GenerationError
|
||||
) -> str:
|
||||
"""Convert GenerationError to streaming error response."""
|
||||
return self.create_streaming_error_response(
|
||||
str(e),
|
||||
err_type="InternalServerError",
|
||||
status_code=e.status_code,
|
||||
)
|
||||
|
||||
async def _check_model(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
|
||||
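
Taken together, the serving_engine additions define a single propagation path: a finish_reason of "error" raises GenerationError, and each handler converts it into a 500 response. A minimal self-contained sketch of that flow (stand-in names, not vLLM code), mirroring only the helpers added above:

from http import HTTPStatus

class GenerationErrorSketch(Exception):
    """Stand-in mirroring the GenerationError class added above."""
    def __init__(self, message: str = "Internal server error"):
        super().__init__(message)
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR

def raise_if_error(finish_reason: str | None) -> None:
    # Mirrors _raise_if_error: only the "error" finish reason is escalated.
    if finish_reason == "error":
        raise GenerationErrorSketch()

def handle_output(finish_reason: str | None) -> tuple[int, str]:
    # Mirrors how the completion/responses handlers catch GenerationError
    # and return an HTTP 500 body instead of a normal completion.
    try:
        raise_if_error(finish_reason)
        return 200, "ok"
    except GenerationErrorSketch as e:
        return int(e.status_code), str(e)

assert handle_output("stop") == (200, "ok")
assert handle_output("error") == (500, "Internal server error")
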
@@ -50,6 +50,7 @@ from openai.types.responses.response_reasoning_item import (
)
from openai.types.responses.tool import Mcp, Tool
from openai_harmony import Message as OpenAIHarmonyMessage
from pydantic import TypeAdapter

from vllm import envs
from vllm.engine.protocol import EngineClient

@@ -94,7 +95,10 @@ from vllm.entrypoints.openai.protocol import (
ResponseUsage,
StreamingResponsesResponse,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.responses_utils import (
construct_input_messages,

@@ -541,6 +545,8 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer,
request_metadata,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except Exception as e:
return self.create_error_response(str(e))

@@ -648,6 +654,8 @@ class OpenAIServingResponses(OpenAIServing):
status = "incomplete"
elif context.finish_reason == "abort":
status = "cancelled"
else:
self._raise_if_error(context.finish_reason, request.request_id)
else:
status = "incomplete"
elif isinstance(context, ParsableContext):

@@ -673,6 +681,9 @@ class OpenAIServingResponses(OpenAIServing):
assert len(final_res.outputs) == 1
final_output = final_res.outputs[0]

# finish_reason='error' indicates retryable internal error
self._raise_if_error(final_output.finish_reason, request.request_id)

output = self._make_response_output_items(request, final_output, tokenizer)

if request.enable_response_messages:

@@ -1066,6 +1077,8 @@ class OpenAIServingResponses(OpenAIServing):
async for event in generator:
event_deque.append(event)
new_event_signal.set() # Signal new event available
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))

@@ -1089,6 +1102,8 @@ class OpenAIServingResponses(OpenAIServing):
):
try:
response = await self.responses_full_generator(request, *args, **kwargs)
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))

@@ -1227,6 +1242,8 @@ class OpenAIServingResponses(OpenAIServing):
continue
if ctx.last_output.outputs:
output = ctx.last_output.outputs[0]
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request.request_id)
if reasoning_parser:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text,

@@ -1522,6 +1539,9 @@ class OpenAIServingResponses(OpenAIServing):
async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext)

# finish_reason='error' indicates a retryable error
self._raise_if_error(ctx.finish_reason, request.request_id)

if ctx.is_expecting_start():
current_output_index += 1
sent_output_item_added = False

@@ -2016,18 +2036,25 @@ class OpenAIServingResponses(OpenAIServing):
)
)

async for event_data in processer(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
_increment_sequence_number_and_return,
):
yield event_data
try:
async for event_data in processer(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
_increment_sequence_number_and_return,
):
yield event_data
except GenerationError as e:
error_json = self._convert_generation_error_to_streaming_response(e)
yield _increment_sequence_number_and_return(
TypeAdapter(StreamingResponsesResponse).validate_json(error_json)
)
return

async def empty_async_generator():
# A hack to trick Python to think this is a generator but

@@ -397,6 +397,25 @@ class BlockPool:
[block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
)

def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs.

only evicts blocks that are currently cached (have a hash). blocks
with ref_cnt > 0 are not freed from the block pool, only evicted
from the prefix cache hash table.

Args:
block_ids: Set of block IDs to evict from cache.
"""
for block_id in block_ids:
assert block_id < len(self.blocks), (
f"Invalid block_id {block_id} >= {len(self.blocks)}. "
f"This indicates a bug in the KV connector - workers should "
f"only report block IDs that were allocated by the scheduler."
)
block = self.blocks[block_id]
self._maybe_evict_cached_block(block)

def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,

@@ -333,6 +333,14 @@ class KVCacheManager:
"""
self.coordinator.free(request.request_id)

def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs.

Args:
block_ids: Set of block IDs to evict from cache.
"""
self.block_pool.evict_blocks(block_ids)

def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,
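
The two evict_blocks methods above only drop entries from the prefix-cache hash table; blocks that are still referenced stay allocated and simply stop being discoverable for prefix reuse. A standalone toy model of that distinction (not the real BlockPool; names are illustrative):

from dataclasses import dataclass

@dataclass
class ToyBlock:
    block_id: int
    ref_cnt: int
    block_hash: int | None    # None means "not cached"

class ToyPrefixCache:
    def __init__(self, blocks: list[ToyBlock]):
        self.blocks = blocks
        # Lookup table used for prefix-cache hits.
        self.cached = {b.block_hash: b for b in blocks if b.block_hash is not None}

    def evict_blocks(self, block_ids: set[int]) -> None:
        for block_id in block_ids:
            block = self.blocks[block_id]
            if block.block_hash is not None:
                # Remove from the lookup table only; allocation and ref counts untouched.
                self.cached.pop(block.block_hash, None)
                block.block_hash = None

cache = ToyPrefixCache([ToyBlock(0, ref_cnt=1, block_hash=11), ToyBlock(1, ref_cnt=0, block_hash=22)])
cache.evict_blocks({0, 1})
assert cache.cached == {}              # no longer reusable via the prefix cache
assert cache.blocks[0].ref_cnt == 1    # still allocated to its request
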
@@ -106,6 +106,7 @@ class Scheduler(SchedulerInterface):
# KV Connector pushes/pull of remote KVs for P/D and offloading.
self.connector = None
self.connector_prefix_cache_stats: PrefixCacheStats | None = None
self.recompute_kv_load_failures = True
if self.vllm_config.kv_transfer_config is not None:
assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported with KV connectors"

@@ -117,6 +118,10 @@ class Scheduler(SchedulerInterface):
)
if self.log_stats:
self.connector_prefix_cache_stats = PrefixCacheStats()
kv_load_failure_policy = (
self.vllm_config.kv_transfer_config.kv_load_failure_policy
)
self.recompute_kv_load_failures = kv_load_failure_policy == "recompute"

self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
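
The policy is read once at scheduler construction, so it is selected through the engine's KV-transfer configuration. A hedged sketch of opting into the abort-on-failure behaviour (the import path and connector name are assumptions for illustration; the kv_load_failure_policy field and its "recompute"/"fail" values are taken from this diff):

from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="SharedStorageConnector",   # placeholder connector for illustration
    kv_role="kv_both",
    kv_load_failure_policy="fail",           # "recompute" reschedules instead of aborting
)

With "recompute" (the scheduler's default here), affected requests are retried locally; with "fail", they finish with finish_reason="error" and surface to clients as HTTP 500.
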
@@ -1066,7 +1071,7 @@ class Scheduler(SchedulerInterface):
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
assert num_tokens_scheduled > 0
if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
# Skip requests that were recovered from KV load failure
# skip failed or rescheduled requests from KV load failure
continue
request = self.requests.get(req_id)
if request is None:

@@ -1177,6 +1182,21 @@ class Scheduler(SchedulerInterface):
# This is a rare case and unlikely to impact performance.
self.waiting.remove_requests(stopped_preempted_reqs)

if failed_kv_load_req_ids and not self.recompute_kv_load_failures:
requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids]
self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR)
for request in requests:
outputs[request.client_index].append(
EngineCoreOutput(
request_id=request.request_id,
new_token_ids=[],
finish_reason=request.get_finished_reason(),
events=request.take_events(),
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
)
)

# KV Connector: update state for finished KV Transfers.
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)

@@ -1610,8 +1630,11 @@ class Scheduler(SchedulerInterface):
self._free_blocks(self.requests[req_id])

def _update_requests_with_invalid_blocks(
self, requests: Iterable[Request], invalid_block_ids: set[int]
) -> tuple[set[str], int]:
self,
requests: Iterable[Request],
invalid_block_ids: set[int],
evict_blocks: bool = True,
) -> tuple[set[str], int, set[int]]:
"""
Identify and update requests affected by invalid KV cache blocks.

@@ -1623,16 +1646,21 @@ class Scheduler(SchedulerInterface):
Args:
requests: The set of requests to scan for invalid blocks.
invalid_block_ids: IDs of invalid blocks.
evict_blocks: Whether to collect blocks for eviction (False for
async requests which aren't cached yet).

Returns:
tuple:
- affected_req_ids (set[str]): IDs of requests impacted by
invalid blocks.
- total_affected_tokens (int): Total number of tokens that must
be recomputed across all affected requests (for observability).
be recomputed across all affected requests.
- blocks_to_evict (set[int]): Block IDs to evict from cache,
including invalid blocks and downstream dependent blocks.
"""
affected_req_ids: set[str] = set()
total_affected_tokens = 0
blocks_to_evict: set[int] = set()
# If a block is invalid and shared by multiple requests in the batch,
# these requests must be rescheduled, but only the first will recompute
# it. This set tracks blocks already marked for recomputation.

@@ -1690,6 +1718,9 @@ class Scheduler(SchedulerInterface):
)
total_affected_tokens += num_affected_tokens
request.num_external_computed_tokens -= num_affected_tokens
# collect invalid block and all downstream dependent blocks
if evict_blocks:
blocks_to_evict.update(req_block_ids[idx:])

if is_affected:
if not marked_invalid_block:

@@ -1705,47 +1736,70 @@ class Scheduler(SchedulerInterface):

affected_req_ids.add(request.request_id)

return affected_req_ids, total_affected_tokens
return affected_req_ids, total_affected_tokens, blocks_to_evict

def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
total_requests_to_reschedule = 0
total_tokens_to_reschedule = 0
"""
Handle requests affected by invalid KV cache blocks.

# --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) ---
Returns:
Set of affected request IDs to skip in update_from_output main loop.
"""
should_fail = not self.recompute_kv_load_failures

# handle async KV loads (not cached yet, evict_blocks=False)
async_load_reqs = (
req
for req in self.waiting
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
)
async_affected_req_ids, num_tokens_to_reschedule = (
async_failed_req_ids, num_failed_tokens, _ = (
self._update_requests_with_invalid_blocks(
async_load_reqs, invalid_block_ids
async_load_reqs, invalid_block_ids, evict_blocks=False
)
)

total_requests_to_reschedule += len(async_affected_req_ids)
total_tokens_to_reschedule += num_tokens_to_reschedule
total_failed_requests = len(async_failed_req_ids)
total_failed_tokens = num_failed_tokens

# Mark requests with async KV load failures; they will be rescheduled
# once loading completes.
self.failed_recving_kv_req_ids |= async_affected_req_ids

# --- Handle sync KV loads (running requests) ---
sync_affected_req_ids, num_tokens_to_reschedule = (
self._update_requests_with_invalid_blocks(self.running, invalid_block_ids)
# handle sync loads (may be cached, collect blocks for eviction)
sync_failed_req_ids, num_failed_tokens, sync_blocks_to_evict = (
self._update_requests_with_invalid_blocks(
self.running, invalid_block_ids, evict_blocks=True
)
)

total_requests_to_reschedule += len(sync_affected_req_ids)
total_tokens_to_reschedule += num_tokens_to_reschedule
total_failed_requests += len(sync_failed_req_ids)
total_failed_tokens += num_failed_tokens

if total_requests_to_reschedule:
logger.warning(
"Recovered from KV load failure: "
"%d request(s) rescheduled (%d tokens affected).",
total_requests_to_reschedule,
total_tokens_to_reschedule,
if not total_failed_requests:
return set()

# evict invalid blocks and downstream dependent blocks from cache
# only when not using recompute policy (where blocks will be recomputed
# and reused by other requests sharing them)
if sync_blocks_to_evict and not self.recompute_kv_load_failures:
self.kv_cache_manager.evict_blocks(sync_blocks_to_evict)

if should_fail:
all_failed_req_ids = async_failed_req_ids | sync_failed_req_ids
logger.error(
"Failing %d request(s) due to KV load failure "
"(failure_policy=fail, %d tokens affected). Request IDs: %s",
total_failed_requests,
total_failed_tokens,
all_failed_req_ids,
)
return all_failed_req_ids

# Return the IDs of affected running requests to skip in
# update_from_output.
return sync_affected_req_ids
logger.warning(
"Recovered from KV load failure: "
"%d request(s) rescheduled (%d tokens affected).",
total_failed_requests,
total_failed_tokens,
)

# Mark async requests with KV load failures for retry once loading completes
self.failed_recving_kv_req_ids |= async_failed_req_ids
# Return sync affected IDs to skip in update_from_output
return sync_failed_req_ids
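
In short, _handle_invalid_blocks now branches on the configured policy: "recompute" marks affected requests for rescheduling (async loads retry once loading completes, sync ones are skipped in update_from_output and recomputed), while "fail" finishes them with FINISHED_ERROR and evicts the poisoned blocks so other requests cannot reuse them. A compact standalone restatement of that decision (illustrative only, not scheduler code):

def resolve_kv_load_failure(policy: str, affected_req_ids: set[str]) -> dict[str, str]:
    """Toy mirror of the scheduler's policy branch above."""
    if policy == "recompute":
        # Keep the requests; they are rescheduled and the blocks recomputed.
        return {req_id: "reschedule" for req_id in affected_req_ids}
    if policy == "fail":
        # Finish the requests with finish_reason="error" (surfaced as HTTP 500).
        return {req_id: "finish_error" for req_id in affected_req_ids}
    raise ValueError(f"unknown kv_load_failure_policy: {policy!r}")

assert resolve_kv_load_failure("fail", {"req-1"}) == {"req-1": "finish_error"}
assert resolve_kv_load_failure("recompute", {"req-1"}) == {"req-1": "reschedule"}
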
@@ -19,24 +19,27 @@ from vllm.v1.serial_utils import UtilityResult

# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")


class FinishReason(enum.IntEnum):
"""
Reason a request finished - stop, length, or abort.
Reason a request finished - stop, length, abort, or error.

Int rather than Str for more compact serialization.

stop - a stop string was emitted
length - max_tokens was consumed, or max_model_len was reached
abort - aborted for another reason
abort - aborted by client
error - retryable request-level internal error (e.g., KV load failure).
Invariant: always converted to 500 Internal Server Error.

"""

STOP = 0
LENGTH = 1
ABORT = 2
ERROR = 3

def __str__(self):
return FINISH_REASON_STRINGS[self.value]
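
Because __str__ indexes into FINISH_REASON_STRINGS, the new ERROR member stringifies to "error", which is exactly the value the serving layer checks for. A tiny sketch (assuming FinishReason remains importable from vllm.v1.engine, as in the hunk above):

from vllm.v1.engine import FinishReason

assert FinishReason.ERROR.value == 3
assert str(FinishReason.ERROR) == "error"   # matches the checks in _raise_if_error
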
@@ -255,6 +255,7 @@ class RequestStatus(enum.IntEnum):
FINISHED_LENGTH_CAPPED = enum.auto()
FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
FINISHED_ERROR = enum.auto()

def __str__(self):
return self.name

@@ -277,4 +278,5 @@ _FINISHED_REASON_MAP = {
RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
}