[KV Connector] Test async mode in scheduler tests (#28550)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>

Commit 6e25b1cddf (parent e64011f29a)
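The diff below touches the scheduler tests and two of their utility modules. Instead of monkey-patching `get_num_new_matched_tokens` with `unittest.mock.Mock`, the tests now build a small `MockKVConnector` whose behaviour is configured through a `mock_kv(matched_tokens=..., is_async=...)` helper, and `test_kv_connector_basic` is parametrized over synchronous and asynchronous KV-transfer mode. A minimal sketch of the new setup pattern (names taken from the diff; not a standalone script):

    # The helpers come from the test utils changed further down in this diff.
    from .utils import create_scheduler, mock_kv

    BLOCK_SIZE = 16
    scheduler = create_scheduler(
        enable_prefix_caching=True,
        # matched_tokens: prompt tokens the mock connector reports as externally cached;
        # is_async=True makes requests wait for a remote KV transfer before running.
        use_kv_connector=mock_kv(matched_tokens=BLOCK_SIZE * 2, is_async=True),
        block_size=BLOCK_SIZE,
    )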
@@ -31,11 +31,11 @@ from vllm.v1.kv_cache_interface import (
     KVCacheConfig,
     KVCacheGroupSpec,
 )
-from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 
-from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
+from .utils import EOS_TOKEN_ID, create_requests, create_scheduler, mock_kv
 
 pytestmark = pytest.mark.cpu_test
 
@@ -888,27 +888,65 @@ def _step_until_done(
         all_finished = all_done
 
 
-def test_kv_connector_basic():
+def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
+    """Cycle requests through a KV transfer cycle."""
+
+    # Requests should first transition to WAITING_FOR_REMOTE_KVS
+    output = scheduler.schedule()
+    assert len(scheduler.waiting) == len(req_ids)
+    assert len(scheduler.running) == 0
+    assert len(output.scheduled_new_reqs) == 0
+    for req in scheduler.requests.values():
+        assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
+
+    # No model execution yet
+    EMPTY_OUTPUT = ModelRunnerOutput(
+        req_ids=[],
+        req_id_to_index={},
+        sampled_token_ids=[],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(output, EMPTY_OUTPUT)
+
+    # Simulate KV transfer completion using KVConnectorOutput.finished_recving
+    output = scheduler.schedule()
+    assert len(scheduler.waiting) == len(req_ids)
+    assert len(scheduler.running) == 0
+
+    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
+        req_ids=[],
+        req_id_to_index={},
+        sampled_token_ids=[],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+        kv_connector_output=KVConnectorOutput(finished_recving=req_ids),
+    )
+    scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
+    for req_id in req_ids:
+        assert req_id in scheduler.finished_recving_kv_req_ids
+
+
+@pytest.mark.parametrize("is_async", [False, True])
+def test_kv_connector_basic(is_async: bool):
     """
     Test whether Scheduler with KVConnector schedules tokens, allocates
     memory, and cleans up requests as expected under normal operation.
     """
 
     # Setup Scheduler.
+    BLOCK_SIZE = 16
+    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
     scheduler = create_scheduler(
         enable_prefix_caching=True,
-        use_kv_connector=True,
+        use_kv_connector=mock_kv(
+            matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
+        ),
+        block_size=BLOCK_SIZE,
     )
     NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
-    BLOCK_SIZE = scheduler.cache_config.block_size
-
-    # Mock External Cache Hit.
-    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
-    )
 
     ######################################################
     # FIRST SET OF REQUESTS - External Hit Only
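In async mode the helper above models the scheduler/worker handshake in two steps: the first `schedule()` call leaves every request in `WAITING_FOR_REMOTE_KVS` and schedules nothing, and the requests are only released once an output arrives whose `kv_connector_output.finished_recving` lists their IDs. A compressed sketch of that second step for a single hypothetical request `"req-0"`, mirroring the code in the hunk above:

    # Worker reports that the remote KV blocks for "req-0" have arrived.
    finished = ModelRunnerOutput(
        req_ids=[],
        req_id_to_index={},
        sampled_token_ids=[],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
        kv_connector_output=KVConnectorOutput(finished_recving=["req-0"]),
    )
    scheduler.update_from_output(scheduler.schedule(), finished)
    assert "req-0" in scheduler.finished_recving_kv_req_ids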
@@ -928,6 +966,9 @@ def test_kv_connector_basic():
         req_ids.append(request.request_id)
         req_to_index[request.request_id] = i
 
+    if is_async:
+        _step_until_kv_transfer_finished(scheduler, req_ids)
+
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
@@ -978,6 +1019,9 @@ def test_kv_connector_basic():
         req_ids.append(request.request_id)
         req_to_index[request.request_id] = i
 
+    if is_async:
+        _step_until_kv_transfer_finished(scheduler, req_ids)
+
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
@@ -1020,17 +1064,10 @@ def test_external_prefix_cache_metrics():
     """
 
     # Setup Scheduler.
+    NUM_MATCHED_NEW_TOKENS = 4
     scheduler = create_scheduler(
         enable_prefix_caching=False,
-        use_kv_connector=True,
+        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
     )
-
-    # Mock connector to simulate a partial external cache hit
-    NUM_MATCHED_NEW_TOKENS = 4
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
-    )
 
     # --- Prepare simple requests ---
@@ -1085,21 +1122,16 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
     # Setup Scheduler With Mock External Cache Hit.
     BLOCK_SIZE = 4
     NUM_BLOCKS = 10
+    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
     scheduler = create_scheduler(
         enable_prefix_caching=True,
-        use_kv_connector=True,
+        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
         block_size=BLOCK_SIZE,
         num_blocks=NUM_BLOCKS,
         # encoder connector should not affect test results
         use_ec_connector=use_ec_connector,
         ec_role=ec_role,
     )
-    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
-    )
 
     # Create two requests. The second request will not be able to
     # allocate slots because it will not have enough blocks.
@@ -1174,9 +1206,10 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
     BLOCK_SIZE = 2
     # NOTE: there is 1 null block, so this is 6 blocks.
     NUM_BLOCKS = 7
+    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
     scheduler = create_scheduler(
         enable_prefix_caching=True,
-        use_kv_connector=True,
+        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
         block_size=BLOCK_SIZE,
         num_blocks=NUM_BLOCKS,
         # encoder connector should not affect test results
@@ -1184,13 +1217,6 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
         ec_role=ec_role,
     )
-
-    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
-    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
-    scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS,
-        False,
-    )
 
     # Create two requests.
     # Both can be scheduled at first, but the second request
     # will be preempted and re-scheduled.
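The hunks above in this file all follow one pattern: the `NUM_MATCHED_NEW_TOKENS` constant moves above `create_scheduler(...)` so it can be passed to `mock_kv(...)`, and the hand-rolled `Mock` on `scheduler.connector.get_num_new_matched_tokens` disappears because the mock connector now answers that query itself. What the scheduler sees from the connector, sketched with the signature of the connector added at the end of this diff:

    # (num_externally_cached_tokens, load_is_async)
    num_tokens, load_async = scheduler.connector.get_num_new_matched_tokens(request, 0)
    # is_async=False -> (N, False): the cached tokens are usable immediately
    # is_async=True  -> (N, True):  the request parks in WAITING_FOR_REMOTE_KVS first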
--- next file: the scheduler-test utils (create_scheduler / mock_kv) ---

@@ -3,6 +3,7 @@
 
 import torch
 
+from tests.v1.kv_connector.unit.utils import MockKVConfig
 from vllm.config import (
     CacheConfig,
     ECTransferConfig,
@@ -33,6 +34,10 @@ from vllm.v1.structured_output import StructuredOutputManager
 EOS_TOKEN_ID = 50256
 
 
+def mock_kv(matched_tokens: int, is_async: bool):
+    return MockKVConfig(matched_tokens=matched_tokens, is_async=is_async)
+
+
 def create_scheduler(
     model: str = "facebook/opt-125m",
     max_num_seqs: int = 16,
@@ -40,7 +45,7 @@ def create_scheduler(
     enable_prefix_caching: bool | None = None,
     long_prefill_token_threshold: int = 0,
     disable_chunked_mm_input: bool = False,
-    use_kv_connector: bool = False,
+    use_kv_connector: None | bool | MockKVConfig = None,
     num_blocks: int = 10000,
     block_size: int = 16,
     max_model_len: int | None = None,
@@ -94,15 +99,22 @@ def create_scheduler(
         cache_dtype="auto",
         **kwargs_cache,
     )
-    kv_transfer_config = (
-        KVTransferConfig(
+    kv_transfer_config = None
+    if isinstance(use_kv_connector, MockKVConfig):
+        kv_transfer_config = KVTransferConfig(
+            kv_connector="MockKVConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={
+                "matched_tokens": use_kv_connector.matched_tokens,
+                "is_async": use_kv_connector.is_async,
+            },
+        )
+    elif use_kv_connector:
+        kv_transfer_config = KVTransferConfig(
             kv_connector="SharedStorageConnector",
             kv_role="kv_both",
             kv_connector_extra_config={"shared_storage_path": "local_storage"},
         )
-        if use_kv_connector
-        else None
-    )
 
     speculative_config: SpeculativeConfig | None = None
     if num_speculative_tokens is not None:
--- next file: tests/v1/kv_connector/unit/utils.py (per the import above) ---

@@ -3,7 +3,8 @@
 import tempfile
 from collections import defaultdict
 from collections.abc import Callable
-from itertools import count
+from dataclasses import dataclass
+from itertools import chain, count
 from typing import Any
 
 import torch
@@ -18,13 +19,18 @@ from vllm.config import (
     VllmConfig,
 )
 from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorBase_V1,
+    KVConnectorMetadata,
+    KVConnectorRole,
+)
 from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
     SharedStorageConnector,
 )
 from vllm.utils.hashing import sha256
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
-from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.kv_cache_interface import (
     FullAttentionSpec,
     KVCacheConfig,
@@ -307,6 +313,82 @@ class TestSharedStorageConnector(SharedStorageConnector):
         return attr
 
 
+@dataclass(frozen=True)
+class MockKVConfig:
+    matched_tokens: int = 0
+    is_async: bool = False
+
+
+class MockKVConnectorMetadata(KVConnectorMetadata):
+    def __init__(self):
+        # Scheduler tests check metadata.requests
+        self.requests: list = []
+
+
+class MockKVConnector(KVConnectorBase_V1):
+    """Mock KV connector for scheduler tests, supporting both sync and async mode."""
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        role: KVConnectorRole,
+        kv_cache_config: KVCacheConfig | None = None,
+    ):
+        super().__init__(vllm_config, role, kv_cache_config)
+        extra_config = self._kv_transfer_config.kv_connector_extra_config
+        self.config = MockKVConfig(
+            matched_tokens=extra_config["matched_tokens"],
+            is_async=extra_config["is_async"],
+        )
+
+    def get_num_new_matched_tokens(
+        self,
+        request: Request,
+        num_computed_tokens: int,
+    ) -> tuple[int | None, bool]:
+        return (self.config.matched_tokens, self.config.is_async)
+
+    def update_state_after_alloc(
+        self,
+        request: Request,
+        blocks: KVCacheBlocks,
+        num_external_tokens: int,
+    ):
+        pass
+
+    def build_connector_meta(
+        self, scheduler_output: SchedulerOutput
+    ) -> KVConnectorMetadata:
+        metadata = MockKVConnectorMetadata()
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for req_id in chain(
+            (req.req_id for req in scheduler_output.scheduled_new_reqs),
+            (
+                req_id
+                for req_id in cached_reqs.req_ids
+                if req_id in cached_reqs.resumed_req_ids
+            ),
+        ):
+            metadata.requests.append({"req_id": req_id})
+        return metadata
+
+    def start_load_kv(self, kv_caches, finished_req_ids):
+        pass
+
+    def wait_for_layer_load(self, layer_name):
+        pass
+
+    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
+        pass
+
+    def wait_for_save(self):
+        pass
+
+
 KVConnectorFactory.register_connector(
     "TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__
 )
+
+KVConnectorFactory.register_connector(
+    "MockKVConnector", __name__, MockKVConnector.__name__
+)
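For the mock to be constructible from a `KVTransferConfig`, it is registered with the `KVConnectorFactory` under the name `"MockKVConnector"`, and its `matched_tokens`/`is_async` settings round-trip through `kv_connector_extra_config` (read back in `MockKVConnector.__init__`). Roughly the config that `create_scheduler` builds for it, sketched from the hunks in this diff with illustrative values:

    kv_transfer_config = KVTransferConfig(
        kv_connector="MockKVConnector",
        kv_role="kv_both",
        # Turned back into a MockKVConfig inside the connector's __init__.
        kv_connector_extra_config={"matched_tokens": 32, "is_async": True},
    )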