diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index d5b829e79b8f..d31338220fca 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -31,11 +31,11 @@ from vllm.v1.kv_cache_interface import (
     KVCacheConfig,
     KVCacheGroupSpec,
 )
-from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 
-from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
+from .utils import EOS_TOKEN_ID, create_requests, create_scheduler, mock_kv
 
 pytestmark = pytest.mark.cpu_test
 
@@ -888,27 +888,65 @@ def _step_until_done(
         all_finished = all_done
 
 
-def test_kv_connector_basic():
+def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
+    """Cycle requests through a KV transfer cycle."""
+
+    # Requests should first transition to WAITING_FOR_REMOTE_KVS
+    output = scheduler.schedule()
+    assert len(scheduler.waiting) == len(req_ids)
+    assert len(scheduler.running) == 0
+    assert len(output.scheduled_new_reqs) == 0
+    for req in scheduler.requests.values():
+        assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
+
+    # No model execution yet
+    EMPTY_OUTPUT = ModelRunnerOutput(
+        req_ids=[],
+        req_id_to_index={},
+        sampled_token_ids=[],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(output, EMPTY_OUTPUT)
+
+    # Simulate KV transfer completion using KVConnectorOutput.finished_recving
+    output = scheduler.schedule()
+    assert len(scheduler.waiting) == len(req_ids)
+    assert len(scheduler.running) == 0
+
+    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
+        req_ids=[],
+        req_id_to_index={},
+        sampled_token_ids=[],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+        kv_connector_output=KVConnectorOutput(finished_recving=req_ids),
+    )
+    scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
+    for req_id in req_ids:
+        assert req_id in scheduler.finished_recving_kv_req_ids
+
+
+@pytest.mark.parametrize("is_async", [False, True])
+def test_kv_connector_basic(is_async: bool):
     """
     Test whether Scheduler with KVConnector schedules tokens, allocates memory,
     and cleans up requests as expected under normal operation.
     """
 
     # Setup Scheduler.
+    BLOCK_SIZE = 16
+    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
     scheduler = create_scheduler(
         enable_prefix_caching=True,
-        use_kv_connector=True,
+        use_kv_connector=mock_kv(
+            matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
+        ),
+        block_size=BLOCK_SIZE,
     )
     NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
-    BLOCK_SIZE = scheduler.cache_config.block_size
-
-    # Mock External Cache Hit.
-    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
-    )
 
     ######################################################
     # FIRST SET OF REQUESTS - External Hit Only
@@ -928,6 +966,9 @@ def test_kv_connector_basic():
         req_ids.append(request.request_id)
         req_to_index[request.request_id] = i
 
+    if is_async:
+        _step_until_kv_transfer_finished(scheduler, req_ids)
+
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
@@ -978,6 +1019,9 @@ def test_kv_connector_basic():
         req_ids.append(request.request_id)
         req_to_index[request.request_id] = i
 
+    if is_async:
+        _step_until_kv_transfer_finished(scheduler, req_ids)
+
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
@@ -1020,17 +1064,10 @@ def test_external_prefix_cache_metrics():
     """
 
     # Setup Scheduler.
+    NUM_MATCHED_NEW_TOKENS = 4
     scheduler = create_scheduler(
         enable_prefix_caching=False,
-        use_kv_connector=True,
-    )
-
-    # Mock connector to simulate a partial external cache hit
-    NUM_MATCHED_NEW_TOKENS = 4
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
+        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
     )
 
     # --- Prepare simple requests ---
@@ -1085,21 +1122,16 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
     # Setup Scheduler With Mock External Cache Hit.
     BLOCK_SIZE = 4
     NUM_BLOCKS = 10
+    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
     scheduler = create_scheduler(
         enable_prefix_caching=True,
-        use_kv_connector=True,
+        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
         block_size=BLOCK_SIZE,
         num_blocks=NUM_BLOCKS,
         # encoder connector should not affect test results
         use_ec_connector=use_ec_connector,
         ec_role=ec_role,
     )
-    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
-    )
 
     # Create two requests. The second request will not be able to
     # allocate slots because it will not have enough blocks.
@@ -1174,9 +1206,10 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
     BLOCK_SIZE = 2
     # NOTE: there is 1 null block, so this is 6 blocks.
     NUM_BLOCKS = 7
+    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
     scheduler = create_scheduler(
         enable_prefix_caching=True,
-        use_kv_connector=True,
+        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
         block_size=BLOCK_SIZE,
         num_blocks=NUM_BLOCKS,
         # encoder connector should not affect test results
@@ -1184,13 +1217,6 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
         ec_role=ec_role,
     )
 
-    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
-    )
-
     # Create two requests.
     # Both can be scheduled at first, but the second request
    # will be preempted and re-scheduled.
diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py
index 3692e633322e..65511c17473b 100644
--- a/tests/v1/core/utils.py
+++ b/tests/v1/core/utils.py
@@ -3,6 +3,7 @@
 
 import torch
 
+from tests.v1.kv_connector.unit.utils import MockKVConfig
 from vllm.config import (
     CacheConfig,
     ECTransferConfig,
@@ -33,6 +34,10 @@ from vllm.v1.structured_output import StructuredOutputManager
 EOS_TOKEN_ID = 50256
 
 
+def mock_kv(matched_tokens: int, is_async: bool):
+    return MockKVConfig(matched_tokens=matched_tokens, is_async=is_async)
+
+
 def create_scheduler(
     model: str = "facebook/opt-125m",
     max_num_seqs: int = 16,
@@ -40,7 +45,7 @@ def create_scheduler(
     enable_prefix_caching: bool | None = None,
     long_prefill_token_threshold: int = 0,
     disable_chunked_mm_input: bool = False,
-    use_kv_connector: bool = False,
+    use_kv_connector: None | bool | MockKVConfig = None,
     num_blocks: int = 10000,
     block_size: int = 16,
     max_model_len: int | None = None,
@@ -94,15 +99,22 @@ def create_scheduler(
         cache_dtype="auto",
         **kwargs_cache,
     )
-    kv_transfer_config = (
-        KVTransferConfig(
+    kv_transfer_config = None
+    if isinstance(use_kv_connector, MockKVConfig):
+        kv_transfer_config = KVTransferConfig(
+            kv_connector="MockKVConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={
+                "matched_tokens": use_kv_connector.matched_tokens,
+                "is_async": use_kv_connector.is_async,
+            },
+        )
+    elif use_kv_connector:
+        kv_transfer_config = KVTransferConfig(
             kv_connector="SharedStorageConnector",
             kv_role="kv_both",
             kv_connector_extra_config={"shared_storage_path": "local_storage"},
         )
-        if use_kv_connector
-        else None
-    )
 
     speculative_config: SpeculativeConfig | None = None
     if num_speculative_tokens is not None:
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index f0031643aa9d..f35f91bb3adf 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -3,7 +3,8 @@
 import tempfile
 from collections import defaultdict
 from collections.abc import Callable
-from itertools import count
+from dataclasses import dataclass
+from itertools import chain, count
 from typing import Any
 
 import torch
@@ -18,13 +19,18 @@ from vllm.config import (
     VllmConfig,
 )
 from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorBase_V1,
+    KVConnectorMetadata,
+    KVConnectorRole,
+)
 from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
     SharedStorageConnector,
 )
 from vllm.utils.hashing import sha256
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
-from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.kv_cache_interface import (
     FullAttentionSpec,
     KVCacheConfig,
@@ -307,6 +313,82 @@ class TestSharedStorageConnector(SharedStorageConnector):
         return attr
 
 
+@dataclass(frozen=True)
+class MockKVConfig:
+    matched_tokens: int = 0
+    is_async: bool = False
+
+
+class MockKVConnectorMetadata(KVConnectorMetadata):
+    def __init__(self):
+        # Scheduler tests check metadata.requests
+        self.requests: list = []
+
+
+class MockKVConnector(KVConnectorBase_V1):
+    """Mock KV connector for scheduler tests, supporting both sync and async mode."""
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        role: KVConnectorRole,
+        kv_cache_config: KVCacheConfig | None = None,
+    ):
+        super().__init__(vllm_config, role, kv_cache_config)
+        extra_config = self._kv_transfer_config.kv_connector_extra_config
+        self.config = MockKVConfig(
+            matched_tokens=extra_config["matched_tokens"],
+            is_async=extra_config["is_async"],
+        )
+
+    def get_num_new_matched_tokens(
+        self,
+        request: Request,
+        num_computed_tokens: int,
+    ) -> tuple[int | None, bool]:
+        return (self.config.matched_tokens, self.config.is_async)
+
+    def update_state_after_alloc(
+        self,
+        request: Request,
+        blocks: KVCacheBlocks,
+        num_external_tokens: int,
+    ):
+        pass
+
+    def build_connector_meta(
+        self, scheduler_output: SchedulerOutput
+    ) -> KVConnectorMetadata:
+        metadata = MockKVConnectorMetadata()
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for req_id in chain(
+            (req.req_id for req in scheduler_output.scheduled_new_reqs),
+            (
+                req_id
+                for req_id in cached_reqs.req_ids
+                if req_id in cached_reqs.resumed_req_ids
+            ),
+        ):
+            metadata.requests.append({"req_id": req_id})
+        return metadata
+
+    def start_load_kv(self, kv_caches, finished_req_ids):
+        pass
+
+    def wait_for_layer_load(self, layer_name):
+        pass
+
+    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
+        pass
+
+    def wait_for_save(self):
+        pass
+
+
 KVConnectorFactory.register_connector(
     "TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__
 )
+
+KVConnectorFactory.register_connector(
+    "MockKVConnector", __name__, MockKVConnector.__name__
+)
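
Illustrative usage (not part of the patch): a minimal sketch of how a scheduler test could opt into the new mock connector via the mock_kv helper added above. The test name and the create_requests parameter names are assumptions for illustration; create_scheduler, mock_kv, and the async behaviour asserted below come from the diff itself.

    # Hypothetical test sketch, not included in the diff above.
    # mock_kv() builds a MockKVConfig, which create_scheduler() translates into a
    # KVTransferConfig pointing at the registered "MockKVConnector".
    from .utils import create_requests, create_scheduler, mock_kv

    def test_async_external_hit_sketch():
        scheduler = create_scheduler(
            enable_prefix_caching=True,
            use_kv_connector=mock_kv(matched_tokens=32, is_async=True),
            block_size=16,
        )
        # Parameter names for create_requests are assumed here.
        for request in create_requests(num_requests=2, num_tokens=64):
            scheduler.add_request(request)

        # With is_async=True the connector reports the external hit as pending,
        # so the first schedule() parks the requests in WAITING_FOR_REMOTE_KVS
        # instead of scheduling them (mirrors _step_until_kv_transfer_finished).
        output = scheduler.schedule()
        assert len(output.scheduled_new_reqs) == 0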