[KV Connector] Test async mode in scheduler tests (#28550)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Mark McLoughlin authored on 2025-11-13 23:30:59 +00:00; committed by GitHub
parent e64011f29a
commit 6e25b1cddf
3 changed files with 165 additions and 45 deletions

View File

@@ -31,11 +31,11 @@ from vllm.v1.kv_cache_interface import (
     KVCacheConfig,
     KVCacheGroupSpec,
 )
-from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
-from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
+from .utils import EOS_TOKEN_ID, create_requests, create_scheduler, mock_kv

 pytestmark = pytest.mark.cpu_test
@@ -888,27 +888,65 @@ def _step_until_done(
         all_finished = all_done


-def test_kv_connector_basic():
+def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
+    """Cycle requests through a KV transfer cycle."""
+    # Requests should first transition to WAITING_FOR_REMOTE_KVS
+    output = scheduler.schedule()
+    assert len(scheduler.waiting) == len(req_ids)
+    assert len(scheduler.running) == 0
+    assert len(output.scheduled_new_reqs) == 0
+    for req in scheduler.requests.values():
+        assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
+
+    # No model execution yet
+    EMPTY_OUTPUT = ModelRunnerOutput(
+        req_ids=[],
+        req_id_to_index={},
+        sampled_token_ids=[],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(output, EMPTY_OUTPUT)
+
+    # Simulate KV transfer completion using KVConnectorOutput.finished_recving
+    output = scheduler.schedule()
+    assert len(scheduler.waiting) == len(req_ids)
+    assert len(scheduler.running) == 0
+    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
+        req_ids=[],
+        req_id_to_index={},
+        sampled_token_ids=[],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+        kv_connector_output=KVConnectorOutput(finished_recving=req_ids),
+    )
+    scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
+    for req_id in req_ids:
+        assert req_id in scheduler.finished_recving_kv_req_ids
+
+
+@pytest.mark.parametrize("is_async", [False, True])
+def test_kv_connector_basic(is_async: bool):
     """
     Test whether Scheduler with KVConnector schedules tokens, allocates
     memory, and cleans up requests as expected under normal operation.
     """
     # Setup Scheduler.
+    BLOCK_SIZE = 16
+    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
     scheduler = create_scheduler(
         enable_prefix_caching=True,
-        use_kv_connector=True,
+        use_kv_connector=mock_kv(
+            matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
+        ),
+        block_size=BLOCK_SIZE,
     )
     NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
-    BLOCK_SIZE = scheduler.cache_config.block_size
-
-    # Mock External Cache Hit.
-    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
-    )

     ######################################################
     # FIRST SET OF REQUESTS - External Hit Only
@@ -928,6 +966,9 @@ def test_kv_connector_basic():
         req_ids.append(request.request_id)
         req_to_index[request.request_id] = i

+    if is_async:
+        _step_until_kv_transfer_finished(scheduler, req_ids)
+
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
@@ -978,6 +1019,9 @@ def test_kv_connector_basic():
         req_ids.append(request.request_id)
         req_to_index[request.request_id] = i

+    if is_async:
+        _step_until_kv_transfer_finished(scheduler, req_ids)
+
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
@@ -1020,17 +1064,10 @@ def test_external_prefix_cache_metrics():
     """
     # Setup Scheduler.
+    NUM_MATCHED_NEW_TOKENS = 4
     scheduler = create_scheduler(
         enable_prefix_caching=False,
-        use_kv_connector=True,
-    )
-
-    # Mock connector to simulate a partial external cache hit
-    NUM_MATCHED_NEW_TOKENS = 4
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
+        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
     )

     # --- Prepare simple requests ---
@@ -1085,21 +1122,16 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
     # Setup Scheduler With Mock External Cache Hit.
     BLOCK_SIZE = 4
     NUM_BLOCKS = 10
+    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
     scheduler = create_scheduler(
         enable_prefix_caching=True,
-        use_kv_connector=True,
+        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
         block_size=BLOCK_SIZE,
         num_blocks=NUM_BLOCKS,
         # encoder connector should not affect test results
         use_ec_connector=use_ec_connector,
         ec_role=ec_role,
     )
-    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
-    )

     # Create two requests. The second request will not be able to
     # allocate slots because it will not have enough blocks.
@@ -1174,9 +1206,10 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
     BLOCK_SIZE = 2
     # NOTE: there is 1 null block, so this is 6 blocks.
     NUM_BLOCKS = 7
+    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
     scheduler = create_scheduler(
         enable_prefix_caching=True,
-        use_kv_connector=True,
+        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
         block_size=BLOCK_SIZE,
         num_blocks=NUM_BLOCKS,
         # encoder connector should not affect test results
@@ -1184,13 +1217,6 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
         ec_role=ec_role,
     )
-    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
-    )

     # Create two requests.
     # Both can be scheduled at first, but the second request
     # will be preempted and re-scheduled.
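For readers new to this flow: the (num_tokens, load_async) tuple returned by get_num_new_matched_tokens() is what distinguishes the two modes the test now parametrizes over. The self-contained Python toy below (all class and attribute names are invented for illustration; this is not vLLM code) models that handshake as the diff describes it: an async external hit parks the request until a later finished_recving report releases it.

# Toy model of the async KV-transfer handshake exercised by
# _step_until_kv_transfer_finished above. Names are illustrative only.
from dataclasses import dataclass, field


@dataclass
class ToyConnector:
    matched_tokens: int
    is_async: bool

    def get_num_new_matched_tokens(self, request_id: str, num_computed_tokens: int):
        # Mirrors the mock connector's contract: (external hit length, load-async flag).
        return self.matched_tokens, self.is_async


@dataclass
class ToyScheduler:
    connector: ToyConnector
    waiting_for_remote_kvs: set = field(default_factory=set)
    runnable: set = field(default_factory=set)

    def add_request(self, request_id: str) -> None:
        matched, load_async = self.connector.get_num_new_matched_tokens(request_id, 0)
        if load_async and matched > 0:
            # Async mode: park the request until the transfer completes.
            self.waiting_for_remote_kvs.add(request_id)
        else:
            # Sync mode: the matched tokens are usable immediately.
            self.runnable.add(request_id)

    def update_from_output(self, finished_recving: list) -> None:
        # Analogous to handling KVConnectorOutput.finished_recving.
        for request_id in finished_recving:
            self.waiting_for_remote_kvs.discard(request_id)
            self.runnable.add(request_id)


if __name__ == "__main__":
    sched = ToyScheduler(ToyConnector(matched_tokens=32, is_async=True))
    sched.add_request("req-0")
    assert sched.waiting_for_remote_kvs == {"req-0"} and not sched.runnable
    sched.update_from_output(finished_recving=["req-0"])
    assert sched.runnable == {"req-0"}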

View File

@@ -3,6 +3,7 @@

 import torch

+from tests.v1.kv_connector.unit.utils import MockKVConfig
 from vllm.config import (
     CacheConfig,
     ECTransferConfig,
@@ -33,6 +34,10 @@ from vllm.v1.structured_output import StructuredOutputManager

 EOS_TOKEN_ID = 50256


+def mock_kv(matched_tokens: int, is_async: bool):
+    return MockKVConfig(matched_tokens=matched_tokens, is_async=is_async)
+
+
 def create_scheduler(
     model: str = "facebook/opt-125m",
     max_num_seqs: int = 16,
@@ -40,7 +45,7 @@ def create_scheduler(
     enable_prefix_caching: bool | None = None,
     long_prefill_token_threshold: int = 0,
     disable_chunked_mm_input: bool = False,
-    use_kv_connector: bool = False,
+    use_kv_connector: None | bool | MockKVConfig = None,
     num_blocks: int = 10000,
     block_size: int = 16,
     max_model_len: int | None = None,
@@ -94,15 +99,22 @@ def create_scheduler(
         cache_dtype="auto",
         **kwargs_cache,
     )
-    kv_transfer_config = (
-        KVTransferConfig(
+    kv_transfer_config = None
+    if isinstance(use_kv_connector, MockKVConfig):
+        kv_transfer_config = KVTransferConfig(
+            kv_connector="MockKVConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={
+                "matched_tokens": use_kv_connector.matched_tokens,
+                "is_async": use_kv_connector.is_async,
+            },
+        )
+    elif use_kv_connector:
+        kv_transfer_config = KVTransferConfig(
             kv_connector="SharedStorageConnector",
             kv_role="kv_both",
             kv_connector_extra_config={"shared_storage_path": "local_storage"},
         )
-        if use_kv_connector
-        else None
-    )

     speculative_config: SpeculativeConfig | None = None
     if num_speculative_tokens is not None:
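Taken together with the test changes above, the widened use_kv_connector parameter now accepts three forms. A short usage sketch follows; the tests.v1.core.utils import path is an assumption, since the scheduler test itself uses the relative "from .utils import ..." form shown earlier.

# Usage sketch for the widened use_kv_connector parameter.
# Assumption: create_scheduler and mock_kv are importable from
# tests.v1.core.utils; the test module in this diff imports them
# via "from .utils import ...".
from tests.v1.core.utils import create_scheduler, mock_kv

# 1. Default: no KV connector (kv_transfer_config stays None).
scheduler = create_scheduler()

# 2. Legacy boolean: wires up the real SharedStorageConnector, as before.
scheduler = create_scheduler(use_kv_connector=True)

# 3. MockKVConfig path: wires up MockKVConnector with the given external
#    hit size, in sync (is_async=False) or async (is_async=True) mode.
scheduler = create_scheduler(
    enable_prefix_caching=True,
    use_kv_connector=mock_kv(matched_tokens=32, is_async=True),
    block_size=16,
)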

View File

@@ -3,7 +3,8 @@

 import tempfile
 from collections import defaultdict
 from collections.abc import Callable
-from itertools import count
+from dataclasses import dataclass
+from itertools import chain, count
 from typing import Any

 import torch
@@ -18,13 +19,18 @@ from vllm.config import (
     VllmConfig,
 )
 from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorBase_V1,
+    KVConnectorMetadata,
+    KVConnectorRole,
+)
 from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
     SharedStorageConnector,
 )
 from vllm.utils.hashing import sha256
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
-from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.kv_cache_interface import (
     FullAttentionSpec,
     KVCacheConfig,
@@ -307,6 +313,82 @@ class TestSharedStorageConnector(SharedStorageConnector):
         return attr


+@dataclass(frozen=True)
+class MockKVConfig:
+    matched_tokens: int = 0
+    is_async: bool = False
+
+
+class MockKVConnectorMetadata(KVConnectorMetadata):
+    def __init__(self):
+        # Scheduler tests check metadata.requests
+        self.requests: list = []
+
+
+class MockKVConnector(KVConnectorBase_V1):
+    """Mock KV connector for scheduler tests, supporting both sync and async mode."""
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        role: KVConnectorRole,
+        kv_cache_config: KVCacheConfig | None = None,
+    ):
+        super().__init__(vllm_config, role, kv_cache_config)
+        extra_config = self._kv_transfer_config.kv_connector_extra_config
+        self.config = MockKVConfig(
+            matched_tokens=extra_config["matched_tokens"],
+            is_async=extra_config["is_async"],
+        )
+
+    def get_num_new_matched_tokens(
+        self,
+        request: Request,
+        num_computed_tokens: int,
+    ) -> tuple[int | None, bool]:
+        return (self.config.matched_tokens, self.config.is_async)
+
+    def update_state_after_alloc(
+        self,
+        request: Request,
+        blocks: KVCacheBlocks,
+        num_external_tokens: int,
+    ):
+        pass
+
+    def build_connector_meta(
+        self, scheduler_output: SchedulerOutput
+    ) -> KVConnectorMetadata:
+        metadata = MockKVConnectorMetadata()
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for req_id in chain(
+            (req.req_id for req in scheduler_output.scheduled_new_reqs),
+            (
+                req_id
+                for req_id in cached_reqs.req_ids
+                if req_id in cached_reqs.resumed_req_ids
+            ),
+        ):
+            metadata.requests.append({"req_id": req_id})
+        return metadata
+
+    def start_load_kv(self, kv_caches, finished_req_ids):
+        pass
+
+    def wait_for_layer_load(self, layer_name):
+        pass
+
+    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
+        pass
+
+    def wait_for_save(self):
+        pass
+
+
 KVConnectorFactory.register_connector(
     "TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__
 )
+KVConnectorFactory.register_connector(
+    "MockKVConnector", __name__, MockKVConnector.__name__
+)
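With the registration above in place, a mock_kv(matched_tokens=32, is_async=True) argument to create_scheduler ultimately produces a KVTransferConfig like the sketch below, whose extra-config dict MockKVConnector.__init__ reads back into a MockKVConfig. The "from vllm.config import KVTransferConfig" path is assumed to match the existing helper's usage.

# Sketch of the config create_scheduler builds for
# use_kv_connector=mock_kv(matched_tokens=32, is_async=True).
# Assumption: KVTransferConfig is importable from vllm.config.
from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    # Resolved via KVConnectorFactory.register_connector("MockKVConnector", ...)
    kv_connector="MockKVConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        # Read back in MockKVConnector.__init__ to build its MockKVConfig.
        "matched_tokens": 32,
        "is_async": True,
    },
)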