From 6c2eef5a5d95b3d213856626a40bafb2c5cb1e18 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Tue, 21 Oct 2025 19:30:47 -0400
Subject: [PATCH] [P/D] KVConnector for decode benchmarking (#25986)

Signed-off-by: Tyler Michael Smith
Signed-off-by: Tyler Michael Smith
---
 .../unit/test_decode_bench_connector.py       | 415 ++++++++++++++++++
 .../kv_transfer/kv_connector/factory.py       |   6 +
 .../kv_transfer/kv_connector/v1/__init__.py   |   5 +-
 .../kv_connector/v1/decode_bench_connector.py | 413 +++++++++++++++++
 4 files changed, 838 insertions(+), 1 deletion(-)
 create mode 100644 tests/v1/kv_connector/unit/test_decode_bench_connector.py
 create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py

diff --git a/tests/v1/kv_connector/unit/test_decode_bench_connector.py b/tests/v1/kv_connector/unit/test_decode_bench_connector.py
new file mode 100644
index 0000000000000..24802317a2bbc
--- /dev/null
+++ b/tests/v1/kv_connector/unit/test_decode_bench_connector.py
@@ -0,0 +1,415 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Unit tests for DecodeBenchConnector.
+
+Tests the functionality of the DecodeBenchConnector which fills KV cache
+with dummy values for decode performance benchmarking.
+"""
+
+import pytest
+import torch
+
+from vllm import SamplingParams
+from vllm.config import KVTransferConfig
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
+
+# ruff: noqa: E501
+from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import (
+    DecodeBenchConnector,
+    DecodeBenchConnectorMetadata,
+)
+from vllm.forward_context import ForwardContext
+from vllm.utils.hashing import sha256
+from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.request import Request
+
+from .utils import (
+    EOS_TOKEN_ID,
+    create_model_runner_output,
+    create_scheduler,
+    create_vllm_config,
+)
+
+
+class DecodeBenchTestRunner:
+    """Test runner for DecodeBenchConnector."""
+
+    def __init__(self, block_size: int, num_gpu_blocks: int):
+        self.block_size = block_size
+        self.num_gpu_blocks = num_gpu_blocks
+
+        self.req_id = -1
+
+        # Create vllm config with DecodeBenchConnector
+        vllm_config = create_vllm_config(
+            block_size=block_size, max_num_batched_tokens=1000
+        )
+        vllm_config.kv_transfer_config = KVTransferConfig(
+            kv_connector="DecodeBenchConnector",
+            kv_role="kv_both",
+        )
+
+        self.vllm_config = vllm_config
+        self.scheduler: Scheduler = create_scheduler(
+            vllm_config, num_blocks=num_gpu_blocks
+        )
+
+        # Create worker-side connector
+        self.worker_connector = DecodeBenchConnector(
+            vllm_config, KVConnectorRole.WORKER
+        )
+
+        # Create dummy KV caches for testing
+        # Shape: [num_blocks, 2, num_heads, block_size, head_dim]
+        # Using simplified shape for testing
+        num_heads = 4
+        head_dim = 64
+        self.kv_caches = {
+            f"layer_{i}": torch.zeros(
+                num_gpu_blocks, 2, num_heads, block_size, head_dim
+            )
+            for i in range(2)  # 2 layers for testing
+        }
+
+        # Register KV caches with worker connector
+        self.worker_connector.register_kv_caches(self.kv_caches)
+
+        # Extract scheduler-side connector
+        scheduler_connector = self.scheduler.connector
+        assert scheduler_connector is not None
+        assert isinstance(scheduler_connector, DecodeBenchConnector)
+        self.scheduler_connector: DecodeBenchConnector = scheduler_connector
+
+        init_none_hash(sha256)
+        self._block_hasher = get_request_block_hasher(block_size, sha256)
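+
+        # The worker-side start_load_kv() of this connector ignores the
+        # forward context, so a minimal dummy ForwardContext is enough here.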
+
+        self._dummy_ctx: ForwardContext = ForwardContext(
+            no_compile_layers={}, attn_metadata={}, virtual_engine=0
+        )
+
+    def new_request(self, token_ids: list[int]) -> Request:
+        """Create a new request with given token IDs."""
+        self.req_id += 1
+
+        req = Request(
+            request_id=str(self.req_id),
+            prompt_token_ids=token_ids,
+            sampling_params=SamplingParams(max_tokens=100),
+            pooling_params=None,
+            eos_token_id=EOS_TOKEN_ID,
+            block_hasher=self._block_hasher,
+        )
+
+        self.scheduler.add_request(req)
+        return req
+
+    def run_single_step(self, token_id: int = 0):
+        """Run a single scheduler + worker step."""
+        scheduler_output = self.scheduler.schedule()
+
+        # Get connector metadata
+        kv_connector_metadata = scheduler_output.kv_connector_metadata
+        assert kv_connector_metadata is not None
+        assert isinstance(kv_connector_metadata, DecodeBenchConnectorMetadata)
+
+        # Bind metadata and load KV
+        self.worker_connector.bind_connector_metadata(kv_connector_metadata)
+        self.worker_connector.start_load_kv(self._dummy_ctx)
+
+        if scheduler_output.total_num_scheduled_tokens > 0:
+            self.worker_connector.wait_for_save()
+
+        self.worker_connector.clear_connector_metadata()
+
+        # Create model runner output
+        model_runner_output = create_model_runner_output(
+            reqs=self.scheduler.running,
+            token_id=token_id,
+        )
+
+        self.scheduler.update_from_output(scheduler_output, model_runner_output)
+
+        return scheduler_output, kv_connector_metadata
+
+
+def test_decode_bench_connector_basic():
+    """Test basic functionality of DecodeBenchConnector."""
+    block_size = 16
+    num_gpu_blocks = 100
+
+    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
+
+    # Create a request with multiple blocks worth of tokens
+    num_tokens = block_size * 3  # 3 blocks
+    token_ids = [1] * num_tokens
+
+    req = runner.new_request(token_ids)
+
+    # Run first step - should fill KV cache with dummy values
+    scheduler_output, metadata = runner.run_single_step()
+
+    # Check that get_num_new_matched_tokens returned correct value
+    # Should be num_tokens - 1 (all except the last token for decode)
+    expected_fill_tokens = num_tokens - 1
+
+    # Check metadata has the request to fill
+    assert len(metadata.reqs_to_fill) == 1
+    assert req.request_id in metadata.reqs_to_fill
+
+    block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
+    assert num_tokens_to_fill == expected_fill_tokens
+
+    # For standard attention, there's only one group
+    assert len(block_ids_per_group) == 1
+    block_ids = block_ids_per_group[0]
+
+    # Calculate expected number of blocks
+    expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
+    assert len(block_ids) == expected_num_blocks
+
+    # Verify KV caches were filled with constant value
+    for layer_name, kv_cache in runner.kv_caches.items():
+        for block_id in block_ids:
+            # Check that the block was filled
+            block_data = kv_cache[block_id]
+            # Should be filled with constant value 0.015
+            assert torch.allclose(block_data, torch.tensor(0.015))
+
+
+def test_decode_bench_connector_no_refill():
+    """Test that DecodeBenchConnector only fills once per request."""
+    block_size = 16
+    num_gpu_blocks = 100
+
+    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
+
+    # Create a request
+    num_tokens = block_size * 2
+    token_ids = [1] * num_tokens
+
+    runner.new_request(token_ids)
+
+    # Run first step - should fill KV cache
+    _, metadata1 = runner.run_single_step()
+    assert len(metadata1.reqs_to_fill) == 1
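+
+    # The scheduler-side connector remembers already-filled request IDs in
+    # _filled_requests, so later steps for the same request report nothing
+    # to fill.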
+
+    # Run second step - should NOT fill again (already filled)
+    _, metadata2 = runner.run_single_step()
+    assert len(metadata2.reqs_to_fill) == 0
+
+
+def test_decode_bench_connector_single_token():
+    """Test DecodeBenchConnector with single token request."""
+    block_size = 16
+    num_gpu_blocks = 100
+
+    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
+
+    # Create a request with just 1 token
+    # Should not fill anything (need at least 2 tokens: 1 to fill, 1 to decode)
+    token_ids = [1]
+
+    runner.new_request(token_ids)
+
+    # Run step - should NOT fill KV cache
+    _, metadata = runner.run_single_step()
+    assert len(metadata.reqs_to_fill) == 0
+
+
+def test_decode_bench_connector_two_tokens():
+    """Test DecodeBenchConnector with two token request."""
+    block_size = 16
+    num_gpu_blocks = 100
+
+    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
+
+    # Create a request with 2 tokens
+    # Should fill 1 token (first token), decode the second
+    token_ids = [1, 2]
+
+    req = runner.new_request(token_ids)
+
+    # Run step
+    _, metadata = runner.run_single_step()
+
+    assert len(metadata.reqs_to_fill) == 1
+    assert req.request_id in metadata.reqs_to_fill
+
+    block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
+    assert num_tokens_to_fill == 1
+    # For standard attention, there's only one group
+    assert len(block_ids_per_group) == 1
+    assert len(block_ids_per_group[0]) == 1  # 1 token needs 1 block
+
+
+def test_decode_bench_connector_large_context():
+    """Test DecodeBenchConnector with large context size."""
+    block_size = 16
+    num_gpu_blocks = 1000
+
+    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
+
+    # Create a request with many blocks
+    num_blocks = 20
+    num_tokens = block_size * num_blocks
+    token_ids = list(range(num_tokens))
+
+    req = runner.new_request(token_ids)
+
+    # Run step
+    _, metadata = runner.run_single_step()
+
+    assert len(metadata.reqs_to_fill) == 1
+    assert req.request_id in metadata.reqs_to_fill
+
+    block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
+
+    # Should fill all tokens except the last one
+    expected_fill_tokens = num_tokens - 1
+    assert num_tokens_to_fill == expected_fill_tokens
+
+    # For standard attention, there's only one group
+    assert len(block_ids_per_group) == 1
+    block_ids = block_ids_per_group[0]
+
+    # Calculate expected number of blocks
+    expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
+    assert len(block_ids) == expected_num_blocks
+
+    # Verify blocks were filled
+    for layer_name, kv_cache in runner.kv_caches.items():
+        for block_id in block_ids:
+            block_data = kv_cache[block_id]
+            assert torch.allclose(block_data, torch.tensor(0.015))
+
+
+def test_decode_bench_connector_multiple_requests():
+    """Test DecodeBenchConnector with multiple sequential requests."""
+    block_size = 16
+    num_gpu_blocks = 100
+
+    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
+
+    # First request
+    req1 = runner.new_request([1] * (block_size * 2))
+    _, metadata1 = runner.run_single_step()
+
+    assert len(metadata1.reqs_to_fill) == 1
+    assert req1.request_id in metadata1.reqs_to_fill
+
+    # Complete first request
+    while runner.scheduler.running:
+        runner.run_single_step()
+
+    # Add EOS to finish
+    scheduler_output = runner.scheduler.schedule()
+    model_runner_output = create_model_runner_output(
+        reqs=runner.scheduler.running,
+        token_id=EOS_TOKEN_ID,
+        use_eos=True,
+    )
+    runner.scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    # Second request - should also get filled
+    req2 = runner.new_request([2] * (block_size * 3))
+    _, metadata2 = runner.run_single_step()
+
+    assert len(metadata2.reqs_to_fill) == 1
+    assert req2.request_id in metadata2.reqs_to_fill
+
+    # Different requests should have different metadata
+    _, num_tokens1 = metadata1.reqs_to_fill[req1.request_id]
+    _, num_tokens2 = metadata2.reqs_to_fill[req2.request_id]
+
+    assert num_tokens1 == block_size * 2 - 1
+    assert num_tokens2 == block_size * 3 - 1
+
+
+def test_decode_bench_connector_partial_block():
+    """Test DecodeBenchConnector with partial block filling."""
+    block_size = 16
+    num_gpu_blocks = 100
+
+    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
+
+    # Create a request that doesn't align to block boundaries
+    # e.g., 2.5 blocks worth of tokens
+    num_tokens = block_size * 2 + block_size // 2
+    token_ids = [1] * num_tokens
+
+    req = runner.new_request(token_ids)
+
+    # Run step
+    _, metadata = runner.run_single_step()
+
+    assert len(metadata.reqs_to_fill) == 1
+    assert req.request_id in metadata.reqs_to_fill
+
+    block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
+
+    # Should fill all tokens except the last one
+    expected_fill_tokens = num_tokens - 1
+    assert num_tokens_to_fill == expected_fill_tokens
+
+    # For standard attention, there's only one group
+    assert len(block_ids_per_group) == 1
+    block_ids = block_ids_per_group[0]
+
+    # Should allocate 3 blocks to hold the partial data
+    expected_num_blocks = 3
+    assert len(block_ids) == expected_num_blocks
+
+
+def test_decode_bench_connector_concurrent_requests():
+    """Test DecodeBenchConnector with multiple concurrent requests in the same batch."""
+    block_size = 16
+    num_gpu_blocks = 1000
+
+    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
+
+    # Create multiple requests that will be batched together
+    req1 = runner.new_request([1] * (block_size * 2))
+    req2 = runner.new_request([2] * (block_size * 3))
+    req3 = runner.new_request([3] * (block_size * 1))
+
+    # Run first step - all requests should be filled concurrently
+    _, metadata = runner.run_single_step()
+
+    # All three requests should be in the metadata
+    assert len(metadata.reqs_to_fill) == 3
+    assert req1.request_id in metadata.reqs_to_fill
+    assert req2.request_id in metadata.reqs_to_fill
+    assert req3.request_id in metadata.reqs_to_fill
+
+    # Verify each request has correct fill info
+    block_ids_per_group1, num_tokens1 = metadata.reqs_to_fill[req1.request_id]
+    block_ids_per_group2, num_tokens2 = metadata.reqs_to_fill[req2.request_id]
+    block_ids_per_group3, num_tokens3 = metadata.reqs_to_fill[req3.request_id]
+
+    # Verify token counts (all tokens except last one)
+    assert num_tokens1 == block_size * 2 - 1
+    assert num_tokens2 == block_size * 3 - 1
+    assert num_tokens3 == block_size * 1 - 1
+
+    # Verify block counts for each request
+    assert len(block_ids_per_group1[0]) == 2  # 2 blocks
+    assert len(block_ids_per_group2[0]) == 3  # 3 blocks
+    assert len(block_ids_per_group3[0]) == 1  # 1 block
+
+    # Verify all blocks are filled in KV cache
+    for req_id, (block_ids_per_group, _) in metadata.reqs_to_fill.items():
+        block_ids = block_ids_per_group[0]
+        for layer_name, kv_cache in runner.kv_caches.items():
+            for block_id in block_ids:
+                block_data = kv_cache[block_id]
+                assert torch.allclose(block_data, torch.tensor(0.015))
+
+    # Run second step - should NOT fill again (already filled)
+    _, metadata2 = runner.run_single_step()
+    assert len(metadata2.reqs_to_fill) == 0
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index ff806962028c0..5ef56f6c381f4 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -130,3 +130,9 @@ KVConnectorFactory.register_connector(
     "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
     "OffloadingConnector",
 )
+
+KVConnectorFactory.register_connector(
+    "DecodeBenchConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
+    "DecodeBenchConnector",
+)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
index 034c7afe97a48..bb558c9560297 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
@@ -4,5 +4,8 @@
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1,
     KVConnectorRole,
 )
+from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import (  # noqa: E501
+    DecodeBenchConnector,
+)
-__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
+__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "DecodeBenchConnector"]
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
new file mode 100644
index 0000000000000..17c00b9c3d0ef
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
@@ -0,0 +1,413 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+DecodeBenchConnector: A KV Connector for decode instance performance testing.
+
+This connector emulates a prefill-decode disaggregated setting by filling
+the KV cache with dummy values, allowing measurement of decoder performance
+under larger input sequence lengths (ISL) in resource-limited environments.
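+
+How it works: on the first scheduling step of a request, the scheduler-side
+connector reports all but the last prompt token as externally "matched", so
+the KV cache manager allocates blocks for them; the worker-side connector
+then fills those blocks with dummy values. No KV data is actually
+transferred, so only the decode path is exercised.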
+
+Usage:
+    To use this connector for benchmarking, configure it in the kv_transfer_config:
+
+    Example:
+        vllm serve --kv-transfer-config '{
+            "kv_connector": "DecodeBenchConnector",
+            "kv_role": "kv_both",
+            "kv_connector_extra_config": {
+                "fill_mean": 0.015,
+                "fill_std": 0.0
+            }
+        }'
+
+    Then run your benchmark with desired input/output lengths:
+        vllm bench serve --base-url http://127.0.0.1:8000 --model <model> \\
+            --dataset-name random --random-input-len 40000 \\
+            --random-output-len 100 --max-concurrency 10
+
+Configuration options (via kv_connector_extra_config):
+    - fill_mean (float): Mean value for random normal fill (default: 0.015)
+    - fill_std (float): Standard deviation for random fill (default: 0.0)
+      Set to 0 for constant values, >0 for random sampling
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+import torch
+
+from vllm.distributed.kv_transfer.kv_connector.v1 import (
+    KVConnectorBase_V1,
+    KVConnectorRole,
+)
+from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
+from vllm.logger import init_logger
+from vllm.utils import cdiv
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+    from vllm.config import VllmConfig
+    from vllm.forward_context import ForwardContext
+    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.request import Request
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class DecodeBenchConnectorMetadata(KVConnectorMetadata):
+    """Metadata for DecodeBenchConnector.
+
+    Contains information about which requests need their KV cache filled
+    with dummy values for benchmarking purposes.
+    """
+
+    # request_id -> (block_ids_per_group, num_tokens_to_fill)
+    # block_ids_per_group is a tuple of lists, one per KV cache group
+    # For standard attention: single group, e.g., ([1, 2, 3],)
+    # For MLA: multiple groups, e.g., ([1, 2], [1, 2])
+    reqs_to_fill: dict[str, tuple[tuple[list[int], ...], int]]
+
+
+class DecodeBenchConnector(KVConnectorBase_V1):
+    """
+    A KV Connector for decode instance performance testing.
+
+    This connector fills the KV cache with dummy (non-zero) values to
+    emulate a prefill-decode disaggregated setting, enabling performance
+    testing of the decoder with larger input sequence lengths.
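+
+    Example (offline API): a minimal sketch mirroring the serving config
+    shown in the module docstring:
+
+        from vllm import LLM
+        from vllm.config import KVTransferConfig
+
+        llm = LLM(
+            model="<your-model>",
+            kv_transfer_config=KVTransferConfig(
+                kv_connector="DecodeBenchConnector",
+                kv_role="kv_both",
+            ),
+        )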
+    """
+
+    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+        super().__init__(vllm_config, role)
+
+        self.connector_scheduler: DecodeBenchConnectorScheduler | None = None
+        self.connector_worker: DecodeBenchConnectorWorker | None = None
+
+        if role == KVConnectorRole.SCHEDULER:
+            self.connector_scheduler = DecodeBenchConnectorScheduler(vllm_config)
+        elif role == KVConnectorRole.WORKER:
+            self.connector_worker = DecodeBenchConnectorWorker(vllm_config)
+
+    # ==============================
+    # Worker-side methods
+    # ==============================
+
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        assert self.connector_worker is not None
+        self.connector_worker.register_kv_caches(kv_caches)
+
+    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
+        assert self.connector_worker is not None
+        assert isinstance(self._connector_metadata, DecodeBenchConnectorMetadata)
+        self.connector_worker.start_fill_kv(self._connector_metadata)
+
+    def wait_for_layer_load(self, layer_name: str) -> None:
+        # All operations are synchronous, so nothing to wait for
+        pass
+
+    def save_kv_layer(
+        self,
+        layer_name: str,
+        kv_layer: torch.Tensor,
+        attn_metadata: "AttentionMetadata",
+        **kwargs: Any,
+    ) -> None:
+        # This connector doesn't save KV cache (benchmarking only)
+        pass
+
+    def wait_for_save(self):
+        # This connector doesn't save KV cache (benchmarking only)
+        pass
+
+    # ==============================
+    # Scheduler-side methods
+    # ==============================
+
+    def get_num_new_matched_tokens(
+        self,
+        request: "Request",
+        num_computed_tokens: int,
+    ) -> tuple[int | None, bool]:
+        assert self.connector_scheduler is not None
+        return self.connector_scheduler.get_num_new_matched_tokens(
+            request, num_computed_tokens
+        )
+
+    def update_state_after_alloc(
+        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
+    ):
+        assert self.connector_scheduler is not None
+        return self.connector_scheduler.update_state_after_alloc(
+            request, blocks, num_external_tokens
+        )
+
+    def build_connector_meta(
+        self, scheduler_output: "SchedulerOutput"
+    ) -> KVConnectorMetadata:
+        assert self.connector_scheduler is not None
+        return self.connector_scheduler.build_connector_meta(scheduler_output)
+
+    def request_finished(
+        self,
+        request: "Request",
+        block_ids: list[int],
+    ) -> tuple[bool, dict[str, Any] | None]:
+        assert self.connector_scheduler is not None
+        self.connector_scheduler.request_finished(request)
+        return False, None
+
+
+class DecodeBenchConnectorScheduler:
+    """Scheduler-side implementation for DecodeBenchConnector."""
+
+    def __init__(self, vllm_config: "VllmConfig"):
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+
+        # Track which requests have already been filled
+        self._filled_requests: set[str] = set()
+
+        # Track pending fills for the current scheduler step
+        # request_id -> (block_ids_per_group, num_tokens_to_fill)
+        # Note: _pending_fills doesn't need explicit cleanup - it's cleared
+        # after build_connector_meta() is called in the same scheduler step
+        self._pending_fills: dict[str, tuple[tuple[list[int], ...], int]] = {}
+
+    def get_num_new_matched_tokens(
+        self,
+        request: "Request",
+        num_computed_tokens: int,
+    ) -> tuple[int, bool]:
+        """
+        For new requests, return the number of tokens that should be filled
+        with dummy KV cache values.
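+
+        For example (illustrative numbers): a fresh 4096-token prompt with
+        num_computed_tokens == 0 reports 4095 tokens, so every prompt token
+        except the last is treated as already available and only the final
+        token is actually computed by the model.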
+
+        Returns:
+            (num_tokens_to_fill, is_async)
+            - num_tokens_to_fill: number of uncomputed tokens minus 1
+              (we fill everything except the last token for decode)
+            - is_async: False (synchronous filling)
+        """
+        req_id = request.request_id
+
+        # Only fill once per request on first scheduling
+        if req_id in self._filled_requests:
+            return 0, False
+
+        # Calculate how many tokens we need to fill
+        # Fill all uncomputed tokens except the last one (which will be decoded)
+        # This simulates having processed a long prefill
+        num_uncomputed_tokens = request.num_tokens - num_computed_tokens
+        num_tokens_to_fill = max(0, num_uncomputed_tokens - 1)
+
+        if num_tokens_to_fill == 0:
+            return 0, False
+
+        # Return False for synchronous operation - the fill is fast enough
+        # that async overhead isn't worth it
+        return num_tokens_to_fill, False
+
+    def update_state_after_alloc(
+        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
+    ):
+        """
+        Called after blocks are allocated. Store the block IDs so we can
+        fill them with dummy values.
+
+        Supports both standard attention (single KV cache group) and MLA
+        (multiple KV cache groups).
+        """
+        req_id = request.request_id
+
+        if num_external_tokens == 0:
+            return
+
+        # Get the block IDs that were allocated
+        # block_groups is a tuple of lists, one per KV cache group
+        # For standard attention: 1 group
+        # For MLA: multiple groups (one per attention type)
+        block_groups = blocks.get_block_ids()
+
+        # Calculate how many blocks we need to fill
+        # num_external_tokens are the tokens we said we'd provide
+        num_blocks_to_fill = cdiv(num_external_tokens, self.block_size)
+
+        # Extract the first num_blocks_to_fill blocks from each group
+        # All groups should have the same block IDs for the same request
+        block_ids_per_group = tuple(
+            group_blocks[:num_blocks_to_fill] for group_blocks in block_groups
+        )
+
+        # Store the blocks to fill for all groups. _pending_fills doesn't need
+        # cleanup as it's cleared after build_connector_meta
+        self._pending_fills[req_id] = (
+            block_ids_per_group,
+            num_external_tokens,
+        )
+        self._filled_requests.add(req_id)
+
+        logger.debug(
+            "DecodeBenchConnector: Allocated %d blocks across %d KV cache groups "
+            "for request %s",
+            num_blocks_to_fill,
+            len(block_groups),
+            req_id,
+        )
+
+    def build_connector_meta(
+        self, scheduler_output: "SchedulerOutput"
+    ) -> KVConnectorMetadata:
+        """
+        Build metadata containing information about which blocks to fill
+        with dummy KV values.
+        """
+        meta = DecodeBenchConnectorMetadata(reqs_to_fill=self._pending_fills.copy())
+
+        # Clear pending fills after building metadata
+        self._pending_fills.clear()
+
+        return meta
+
+    def request_finished(self, request: "Request"):
+        """
+        Called when a request has finished. Clean up any state.
+ """ + self._filled_requests.discard(request.request_id) + + +class DecodeBenchConnectorWorker: + """Worker-side implementation for DecodeBenchConnector.""" + + def __init__(self, vllm_config: "VllmConfig"): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + + # Get fill parameters from extra config + kv_transfer_config = vllm_config.kv_transfer_config + assert kv_transfer_config is not None + self.fill_mean = kv_transfer_config.get_from_extra_config("fill_mean", 0.015) + self.fill_std = kv_transfer_config.get_from_extra_config("fill_std", 0.0) + + # Will be populated via register_kv_caches + self.kv_caches: dict[str, torch.Tensor] | None = None + + # Mapping from KV cache group index to list of layer names in that group + self.group_to_layers: dict[int, list[str]] | None = None + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Store references to the KV cache tensors and build group mapping.""" + self.kv_caches = kv_caches + + # For simplicity, assume all layers belong to group 0 (standard attention) + # For MLA models with multiple groups, the metadata will handle the mapping + # We just need to fill the blocks specified in the metadata + self.group_to_layers = {0: list(kv_caches.keys())} + + logger.debug( + "DecodeBenchConnector: Registered %d KV cache layers", + len(kv_caches), + ) + + def start_fill_kv(self, metadata: DecodeBenchConnectorMetadata): + """ + Fill the allocated KV cache blocks with dummy (non-zero) values. + + This simulates having a populated KV cache from a prefill phase, + allowing decode performance testing with larger context sizes. + + Supports both standard attention (single group) and MLA (multiple groups). + """ + if not metadata.reqs_to_fill: + return + + assert self.kv_caches is not None, "KV caches must be registered before filling" + assert self.group_to_layers is not None, "Group mapping must be initialized" + + for req_id, (block_ids_per_group, num_tokens) in metadata.reqs_to_fill.items(): + # Fill blocks for each KV cache group + for group_idx, block_ids in enumerate(block_ids_per_group): + self._fill_blocks(group_idx, block_ids, num_tokens) + + logger.debug( + "DecodeBenchConnector: Filled %d blocks (%d tokens) across %d groups " + "for request %s", + len(block_ids_per_group[0]) if block_ids_per_group else 0, + num_tokens, + len(block_ids_per_group), + req_id, + ) + + def _fill_blocks(self, group_idx: int, block_ids: list[int], num_tokens: int): + """ + Fill specified blocks with dummy non-zero values for a specific KV cache group. 
+
+        Args:
+            group_idx: The KV cache group index to fill
+            block_ids: List of block IDs to fill in this group
+            num_tokens: Total number of tokens to fill across these blocks
+        """
+        if not block_ids:
+            return
+
+        assert self.kv_caches is not None
+        assert self.group_to_layers is not None
+
+        # Get the layers that belong to this group
+        layer_names = self.group_to_layers.get(group_idx, [])
+
+        # Fill only the layers in this group
+        for layer_name in layer_names:
+            if layer_name not in self.kv_caches:
+                logger.warning(
+                    "DecodeBenchConnector: Layer %s not found in KV caches", layer_name
+                )
+                continue
+
+            kv_cache = self.kv_caches[layer_name]
+
+            # Convert block_ids to tensor on device
+            block_ids_tensor = torch.tensor(
+                block_ids, dtype=torch.long, device=kv_cache.device
+            )
+
+            # Filter invalid block IDs
+            valid_mask = block_ids_tensor < kv_cache.shape[0]
+            valid_block_ids = block_ids_tensor[valid_mask]
+
+            if len(valid_block_ids) == 0:
+                continue
+
+            # Create fill values - either constant or random
+            block_shape = kv_cache.shape[1:]
+            if self.fill_std > 0:
+                # Random normal sampling
+                fill_values = torch.normal(
+                    mean=self.fill_mean,
+                    std=self.fill_std,
+                    size=(len(valid_block_ids),) + block_shape,
+                    dtype=kv_cache.dtype,
+                    device=kv_cache.device,
+                )
+            else:
+                # Constant fill value
+                fill_values = torch.full(
+                    (len(valid_block_ids),) + block_shape,
+                    self.fill_mean,
+                    dtype=kv_cache.dtype,
+                    device=kv_cache.device,
+                )
+
+            # Batch fill operation
+            kv_cache[valid_block_ids] = fill_values
+
+        logger.debug(
+            "DecodeBenchConnector: Filled %d blocks in group %d with %s values "
+            "(mean=%.3f, std=%.3f)",
+            len(block_ids),
+            group_idx,
+            "random" if self.fill_std > 0 else "constant",
+            self.fill_mean,
+            self.fill_std,
+        )