[P/D] KVConnector for decode benchmarking (#25986)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Tyler Michael Smith 2025-10-21 19:30:47 -04:00 committed by GitHub
parent 19748806f0
commit 6c2eef5a5d
4 changed files with 838 additions and 1 deletions


@@ -0,0 +1,415 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for DecodeBenchConnector.
Tests the functionality of the DecodeBenchConnector, which fills the KV cache
with dummy values for decode performance benchmarking.
"""
import pytest
import torch
from vllm import SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
# ruff: noqa: E501
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import (
DecodeBenchConnector,
DecodeBenchConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request
from .utils import (
EOS_TOKEN_ID,
create_model_runner_output,
create_scheduler,
create_vllm_config,
)
class DecodeBenchTestRunner:
"""Test runner for DecodeBenchConnector."""
def __init__(self, block_size: int, num_gpu_blocks: int):
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.req_id = -1
# Create vllm config with DecodeBenchConnector
vllm_config = create_vllm_config(
block_size=block_size, max_num_batched_tokens=1000
)
vllm_config.kv_transfer_config = KVTransferConfig(
kv_connector="DecodeBenchConnector",
kv_role="kv_both",
)
self.vllm_config = vllm_config
self.scheduler: Scheduler = create_scheduler(
vllm_config, num_blocks=num_gpu_blocks
)
# Create worker-side connector
self.worker_connector = DecodeBenchConnector(
vllm_config, KVConnectorRole.WORKER
)
# Create dummy KV caches for testing
# Shape: [num_blocks, 2, num_heads, block_size, head_dim]
# Using simplified shape for testing
num_heads = 4
head_dim = 64
self.kv_caches = {
f"layer_{i}": torch.zeros(
num_gpu_blocks, 2, num_heads, block_size, head_dim
)
for i in range(2) # 2 layers for testing
}
# Register KV caches with worker connector
self.worker_connector.register_kv_caches(self.kv_caches)
# Extract scheduler-side connector
scheduler_connector = self.scheduler.connector
assert scheduler_connector is not None
assert isinstance(scheduler_connector, DecodeBenchConnector)
self.scheduler_connector: DecodeBenchConnector = scheduler_connector
init_none_hash(sha256)
self._block_hasher = get_request_block_hasher(block_size, sha256)
self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, attn_metadata={}, virtual_engine=0
)
def new_request(self, token_ids: list[int]) -> Request:
"""Create a new request with given token IDs."""
self.req_id += 1
req = Request(
request_id=str(self.req_id),
prompt_token_ids=token_ids,
sampling_params=SamplingParams(max_tokens=100),
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=self._block_hasher,
)
self.scheduler.add_request(req)
return req
def run_single_step(self, token_id: int = 0):
"""Run a single scheduler + worker step."""
scheduler_output = self.scheduler.schedule()
# Get connector metadata
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, DecodeBenchConnectorMetadata)
# Bind metadata and load KV
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)
if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save()
self.worker_connector.clear_connector_metadata()
# Create model runner output
model_runner_output = create_model_runner_output(
reqs=self.scheduler.running,
token_id=token_id,
)
self.scheduler.update_from_output(scheduler_output, model_runner_output)
return scheduler_output, kv_connector_metadata
def test_decode_bench_connector_basic():
"""Test basic functionality of DecodeBenchConnector."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request with multiple blocks worth of tokens
num_tokens = block_size * 3 # 3 blocks
token_ids = [1] * num_tokens
req = runner.new_request(token_ids)
# Run first step - should fill KV cache with dummy values
scheduler_output, metadata = runner.run_single_step()
# Check that get_num_new_matched_tokens returned correct value
# Should be num_tokens - 1 (all except the last token for decode)
expected_fill_tokens = num_tokens - 1
# Check metadata has the request to fill
assert len(metadata.reqs_to_fill) == 1
assert req.request_id in metadata.reqs_to_fill
block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
assert num_tokens_to_fill == expected_fill_tokens
# For standard attention, there's only one group
assert len(block_ids_per_group) == 1
block_ids = block_ids_per_group[0]
# Calculate expected number of blocks
expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
assert len(block_ids) == expected_num_blocks
# Verify KV caches were filled with constant value
for layer_name, kv_cache in runner.kv_caches.items():
for block_id in block_ids:
# Check that the block was filled
block_data = kv_cache[block_id]
# Should be filled with constant value 0.015
assert torch.allclose(block_data, torch.tensor(0.015))
def test_decode_bench_connector_no_refill():
"""Test that DecodeBenchConnector only fills once per request."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request
num_tokens = block_size * 2
token_ids = [1] * num_tokens
runner.new_request(token_ids)
# Run first step - should fill KV cache
_, metadata1 = runner.run_single_step()
assert len(metadata1.reqs_to_fill) == 1
# Run second step - should NOT fill again (already filled)
_, metadata2 = runner.run_single_step()
assert len(metadata2.reqs_to_fill) == 0
def test_decode_bench_connector_single_token():
"""Test DecodeBenchConnector with single token request."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request with just 1 token
# Should not fill anything (need at least 2 tokens: 1 to fill, 1 to decode)
token_ids = [1]
runner.new_request(token_ids)
# Run step - should NOT fill KV cache
_, metadata = runner.run_single_step()
assert len(metadata.reqs_to_fill) == 0
def test_decode_bench_connector_two_tokens():
"""Test DecodeBenchConnector with two token request."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request with 2 tokens
# Should fill 1 token (first token), decode the second
token_ids = [1, 2]
req = runner.new_request(token_ids)
# Run step
_, metadata = runner.run_single_step()
assert len(metadata.reqs_to_fill) == 1
assert req.request_id in metadata.reqs_to_fill
block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
assert num_tokens_to_fill == 1
# For standard attention, there's only one group
assert len(block_ids_per_group) == 1
assert len(block_ids_per_group[0]) == 1 # 1 token needs 1 block
def test_decode_bench_connector_large_context():
"""Test DecodeBenchConnector with large context size."""
block_size = 16
num_gpu_blocks = 1000
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request with many blocks
num_blocks = 20
num_tokens = block_size * num_blocks
token_ids = list(range(num_tokens))
req = runner.new_request(token_ids)
# Run step
_, metadata = runner.run_single_step()
assert len(metadata.reqs_to_fill) == 1
assert req.request_id in metadata.reqs_to_fill
block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
# Should fill all tokens except the last one
expected_fill_tokens = num_tokens - 1
assert num_tokens_to_fill == expected_fill_tokens
# For standard attention, there's only one group
assert len(block_ids_per_group) == 1
block_ids = block_ids_per_group[0]
# Calculate expected number of blocks
expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
assert len(block_ids) == expected_num_blocks
# Verify blocks were filled
for layer_name, kv_cache in runner.kv_caches.items():
for block_id in block_ids:
block_data = kv_cache[block_id]
assert torch.allclose(block_data, torch.tensor(0.015))
def test_decode_bench_connector_multiple_requests():
"""Test DecodeBenchConnector with multiple sequential requests."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# First request
req1 = runner.new_request([1] * (block_size * 2))
_, metadata1 = runner.run_single_step()
assert len(metadata1.reqs_to_fill) == 1
assert req1.request_id in metadata1.reqs_to_fill
# Complete first request
while runner.scheduler.running:
runner.run_single_step()
# Add EOS to finish
scheduler_output = runner.scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=runner.scheduler.running,
token_id=EOS_TOKEN_ID,
use_eos=True,
)
runner.scheduler.update_from_output(scheduler_output, model_runner_output)
# Second request - should also get filled
req2 = runner.new_request([2] * (block_size * 3))
_, metadata2 = runner.run_single_step()
assert len(metadata2.reqs_to_fill) == 1
assert req2.request_id in metadata2.reqs_to_fill
# Different request should have different metadata
_, num_tokens1 = metadata1.reqs_to_fill[req1.request_id]
_, num_tokens2 = metadata2.reqs_to_fill[req2.request_id]
assert num_tokens1 == block_size * 2 - 1
assert num_tokens2 == block_size * 3 - 1
def test_decode_bench_connector_partial_block():
"""Test DecodeBenchConnector with partial block filling."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request that doesn't align to block boundaries
# e.g., 2.5 blocks worth of tokens
num_tokens = block_size * 2 + block_size // 2
token_ids = [1] * num_tokens
req = runner.new_request(token_ids)
# Run step
_, metadata = runner.run_single_step()
assert len(metadata.reqs_to_fill) == 1
assert req.request_id in metadata.reqs_to_fill
block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
# Should fill all tokens except the last one
expected_fill_tokens = num_tokens - 1
assert num_tokens_to_fill == expected_fill_tokens
# For standard attention, there's only one group
assert len(block_ids_per_group) == 1
block_ids = block_ids_per_group[0]
# Should allocate 3 blocks to hold the partial data
expected_num_blocks = 3
assert len(block_ids) == expected_num_blocks
def test_decode_bench_connector_concurrent_requests():
"""Test DecodeBenchConnector with multiple concurrent requests in the same batch."""
block_size = 16
num_gpu_blocks = 1000
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create multiple requests that will be batched together
req1 = runner.new_request([1] * (block_size * 2))
req2 = runner.new_request([2] * (block_size * 3))
req3 = runner.new_request([3] * (block_size * 1))
# Run first step - all requests should be filled concurrently
_, metadata = runner.run_single_step()
# All three requests should be in the metadata
assert len(metadata.reqs_to_fill) == 3
assert req1.request_id in metadata.reqs_to_fill
assert req2.request_id in metadata.reqs_to_fill
assert req3.request_id in metadata.reqs_to_fill
# Verify each request has correct fill info
block_ids_per_group1, num_tokens1 = metadata.reqs_to_fill[req1.request_id]
block_ids_per_group2, num_tokens2 = metadata.reqs_to_fill[req2.request_id]
block_ids_per_group3, num_tokens3 = metadata.reqs_to_fill[req3.request_id]
# Verify token counts (all tokens except last one)
assert num_tokens1 == block_size * 2 - 1
assert num_tokens2 == block_size * 3 - 1
assert num_tokens3 == block_size * 1 - 1
# Verify block counts for each request
assert len(block_ids_per_group1[0]) == 2 # 2 blocks
assert len(block_ids_per_group2[0]) == 3 # 3 blocks
assert len(block_ids_per_group3[0]) == 1 # 1 block
# Verify all blocks are filled in KV cache
for req_id, (block_ids_per_group, _) in metadata.reqs_to_fill.items():
block_ids = block_ids_per_group[0]
for layer_name, kv_cache in runner.kv_caches.items():
for block_id in block_ids:
block_data = kv_cache[block_id]
assert torch.allclose(block_data, torch.tensor(0.015))
# Run second step - should NOT fill again (already filled)
_, metadata2 = runner.run_single_step()
assert len(metadata2.reqs_to_fill) == 0
if __name__ == "__main__":
pytest.main([__file__, "-v"])
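The block-count assertions in these tests all reduce to the same arithmetic: the connector reports num_tokens - 1 tokens to fill, and those tokens occupy ceil(fill_tokens / block_size) blocks. A minimal standalone sketch of that arithmetic (the helper name and sample values are illustrative, not part of this commit):

block_size = 16

def expected_blocks(num_prompt_tokens: int) -> int:
    # The connector fills every prompt token except the last one...
    fill_tokens = max(0, num_prompt_tokens - 1)
    # ...and needs ceiling-division many blocks to hold them.
    return (fill_tokens + block_size - 1) // block_size

assert expected_blocks(block_size * 3) == 3      # 48 tokens -> fill 47 -> 3 blocks
assert expected_blocks(block_size * 2 + 8) == 3  # 2.5 blocks of tokens -> 3 blocks
assert expected_blocks(2) == 1                   # 2 tokens -> fill 1 -> 1 block
assert expected_blocks(1) == 0                   # single token -> nothing to fill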


@@ -130,3 +130,9 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
"OffloadingConnector",
)
KVConnectorFactory.register_connector(
"DecodeBenchConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
"DecodeBenchConnector",
)
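This registration makes the connector selectable by the name "DecodeBenchConnector". A minimal sketch of picking it up programmatically, mirroring the config used in the unit tests above (the extra-config values shown are the documented defaults, included here only for illustration):

from vllm.config import KVTransferConfig

# The connector name must match the string registered with KVConnectorFactory.
kv_transfer_config = KVTransferConfig(
    kv_connector="DecodeBenchConnector",
    kv_role="kv_both",
    kv_connector_extra_config={"fill_mean": 0.015, "fill_std": 0.0},
)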


@@ -4,5 +4,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import ( # noqa: E501
DecodeBenchConnector,
)
__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "DecodeBenchConnector"]


@@ -0,0 +1,413 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DecodeBenchConnector: A KV Connector for decode instance performance testing.
This connector emulates a prefill-decode disaggregated setting by filling
the KV cache with dummy values, allowing measurement of decoder performance
under larger input sequence lengths (ISL) in resource-limited environments.
Usage:
To use this connector for benchmarking, configure it in the kv_transfer_config:
Example:
vllm serve <model> --kv-transfer-config '{
"kv_connector": "DecodeBenchConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"fill_mean": 0.015,
"fill_std": 0.0
}
}'
Then run your benchmark with desired input/output lengths:
vllm bench serve --base-url http://127.0.0.1:8000 --model <model> \\
--dataset-name random --random-input-len 40000 \\
--random-output-len 100 --max-concurrency 10
Configuration options (via kv_connector_extra_config):
- fill_mean (float): Mean value for random normal fill (default: 0.015)
- fill_std (float): Standard deviation for random fill (default: 0.0)
Set to 0 for constant values, >0 for random sampling
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import torch
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.logger import init_logger
from vllm.utils import cdiv
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class DecodeBenchConnectorMetadata(KVConnectorMetadata):
"""Metadata for DecodeBenchConnector.
Contains information about which requests need their KV cache filled
with dummy values for benchmarking purposes.
"""
# request_id -> (block_ids_per_group, num_tokens_to_fill)
# block_ids_per_group is a tuple of lists, one per KV cache group
# For standard attention: single group, e.g., ([1, 2, 3],)
# For MLA: multiple groups, e.g., ([1, 2], [1, 2])
reqs_to_fill: dict[str, tuple[tuple[list[int], ...], int]]
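# Illustration (hypothetical request id and block ids, not from this commit):
# a standard-attention request that needs 47 tokens filled into blocks 10-12
# would appear here as reqs_to_fill={"0": (([10, 11, 12],), 47)}, while a
# model with two KV cache groups would carry ([10, 11, 12], [10, 11, 12]).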
class DecodeBenchConnector(KVConnectorBase_V1):
"""
A KV Connector for decode instance performance testing.
This connector fills the KV cache with dummy (non-zero) values to
emulate a prefill-decode disaggregated setting, enabling performance
testing of the decoder with larger input sequence lengths.
"""
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config, role)
self.connector_scheduler: DecodeBenchConnectorScheduler | None = None
self.connector_worker: DecodeBenchConnectorWorker | None = None
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = DecodeBenchConnectorScheduler(vllm_config)
elif role == KVConnectorRole.WORKER:
self.connector_worker = DecodeBenchConnectorWorker(vllm_config)
# ==============================
# Worker-side methods
# ==============================
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, DecodeBenchConnectorMetadata)
self.connector_worker.start_fill_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
# All operations are synchronous, so nothing to wait for
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
# This connector doesn't save KV cache (benchmarking only)
pass
def wait_for_save(self):
# This connector doesn't save KV cache (benchmarking only)
pass
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens
)
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens
)
def build_connector_meta(
self, scheduler_output: "SchedulerOutput"
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
self.connector_scheduler.request_finished(request)
return False, None
class DecodeBenchConnectorScheduler:
"""Scheduler-side implementation for DecodeBenchConnector."""
def __init__(self, vllm_config: "VllmConfig"):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
# Track which requests have already been filled
self._filled_requests: set[str] = set()
# Track pending fills for the current scheduler step
# request_id -> (block_ids_per_group, num_tokens_to_fill)
# Note: _pending_fills doesn't need explicit cleanup - it's cleared
# after build_connector_meta() is called in the same scheduler step
self._pending_fills: dict[str, tuple[tuple[list[int], ...], int]] = {}
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
For new requests, return the number of tokens that should be filled
with dummy KV cache values.
Returns:
(num_tokens_to_fill, is_async)
- num_tokens_to_fill: number of uncomputed tokens minus 1
(we fill everything except the last token for decode)
- is_async: False (synchronous filling)
"""
req_id = request.request_id
# Only fill once per request on first scheduling
if req_id in self._filled_requests:
return 0, False
# Calculate how many tokens we need to fill
# Fill all uncomputed tokens except the last one (which will be decoded)
# This simulates having processed a long prefill
num_uncomputed_tokens = request.num_tokens - num_computed_tokens
num_tokens_to_fill = max(0, num_uncomputed_tokens - 1)
if num_tokens_to_fill == 0:
return 0, False
# Return False for synchronous operation - the fill is fast enough
# that async overhead isn't worth it
return num_tokens_to_fill, False
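# Worked example (illustrative numbers): a fresh 48-token prompt with
# num_computed_tokens == 0 reports 48 - 0 - 1 = 47 tokens as externally
# "matched", so the scheduler only runs a real forward pass for the final
# prompt token before decoding begins.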
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Called after blocks are allocated. Store the block IDs so we can
fill them with dummy values.
Supports both standard attention (single KV cache group) and MLA
(multiple KV cache groups).
"""
req_id = request.request_id
if num_external_tokens == 0:
return
# Get the block IDs that were allocated
# block_groups is a tuple of lists, one per KV cache group
# For standard attention: 1 group
# For MLA: multiple groups (one per attention type)
block_groups = blocks.get_block_ids()
# Calculate how many blocks we need to fill
# num_external_tokens are the tokens we said we'd provide
num_blocks_to_fill = cdiv(num_external_tokens, self.block_size)
# Extract the first num_blocks_to_fill blocks from each group
# All groups should have the same block IDs for the same request
block_ids_per_group = tuple(
group_blocks[:num_blocks_to_fill] for group_blocks in block_groups
)
# Store the blocks to fill for all groups. _pending_fills doesn't need cleanup
# as it's cleared after build_connector_meta
self._pending_fills[req_id] = (
block_ids_per_group,
num_external_tokens,
)
self._filled_requests.add(req_id)
logger.debug(
"DecodeBenchConnector: Allocated %d blocks across %d KV cache groups "
"for request %s",
num_blocks_to_fill,
len(block_groups),
req_id,
)
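# Illustration (hypothetical numbers): with block_size=16 and
# num_external_tokens=47, cdiv(47, 16) == 3, so the first three block ids
# allocated to the request are recorded for filling in each KV cache group.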
def build_connector_meta(
self, scheduler_output: "SchedulerOutput"
) -> KVConnectorMetadata:
"""
Build metadata containing information about which blocks to fill
with dummy KV values.
"""
meta = DecodeBenchConnectorMetadata(reqs_to_fill=self._pending_fills.copy())
# Clear pending fills after building metadata
self._pending_fills.clear()
return meta
def request_finished(self, request: "Request"):
"""
Called when a request has finished. Clean up any state.
"""
self._filled_requests.discard(request.request_id)
class DecodeBenchConnectorWorker:
"""Worker-side implementation for DecodeBenchConnector."""
def __init__(self, vllm_config: "VllmConfig"):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
# Get fill parameters from extra config
kv_transfer_config = vllm_config.kv_transfer_config
assert kv_transfer_config is not None
self.fill_mean = kv_transfer_config.get_from_extra_config("fill_mean", 0.015)
self.fill_std = kv_transfer_config.get_from_extra_config("fill_std", 0.0)
# Will be populated via register_kv_caches
self.kv_caches: dict[str, torch.Tensor] | None = None
# Mapping from KV cache group index to list of layer names in that group
self.group_to_layers: dict[int, list[str]] | None = None
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Store references to the KV cache tensors and build group mapping."""
self.kv_caches = kv_caches
# For simplicity, assume all layers belong to group 0 (standard attention)
# For MLA models with multiple groups, the metadata will handle the mapping
# We just need to fill the blocks specified in the metadata
self.group_to_layers = {0: list(kv_caches.keys())}
logger.debug(
"DecodeBenchConnector: Registered %d KV cache layers",
len(kv_caches),
)
def start_fill_kv(self, metadata: DecodeBenchConnectorMetadata):
"""
Fill the allocated KV cache blocks with dummy (non-zero) values.
This simulates having a populated KV cache from a prefill phase,
allowing decode performance testing with larger context sizes.
Supports both standard attention (single group) and MLA (multiple groups).
"""
if not metadata.reqs_to_fill:
return
assert self.kv_caches is not None, "KV caches must be registered before filling"
assert self.group_to_layers is not None, "Group mapping must be initialized"
for req_id, (block_ids_per_group, num_tokens) in metadata.reqs_to_fill.items():
# Fill blocks for each KV cache group
for group_idx, block_ids in enumerate(block_ids_per_group):
self._fill_blocks(group_idx, block_ids, num_tokens)
logger.debug(
"DecodeBenchConnector: Filled %d blocks (%d tokens) across %d groups "
"for request %s",
len(block_ids_per_group[0]) if block_ids_per_group else 0,
num_tokens,
len(block_ids_per_group),
req_id,
)
def _fill_blocks(self, group_idx: int, block_ids: list[int], num_tokens: int):
"""
Fill specified blocks with dummy non-zero values for a specific KV cache group.
Args:
group_idx: The KV cache group index to fill
block_ids: List of block IDs to fill in this group
num_tokens: Total number of tokens to fill across these blocks
"""
if not block_ids:
return
assert self.kv_caches is not None
assert self.group_to_layers is not None
# Get the layers that belong to this group
layer_names = self.group_to_layers.get(group_idx, [])
# Fill only the layers in this group
for layer_name in layer_names:
if layer_name not in self.kv_caches:
logger.warning(
"DecodeBenchConnector: Layer %s not found in KV caches", layer_name
)
continue
kv_cache = self.kv_caches[layer_name]
# Convert block_ids to tensor on device
block_ids_tensor = torch.tensor(
block_ids, dtype=torch.long, device=kv_cache.device
)
# Filter invalid block IDs
valid_mask = block_ids_tensor < kv_cache.shape[0]
valid_block_ids = block_ids_tensor[valid_mask]
if len(valid_block_ids) == 0:
continue
# Create fill values - either constant or random
block_shape = kv_cache.shape[1:]
if self.fill_std > 0:
# Random normal sampling
fill_values = torch.normal(
mean=self.fill_mean,
std=self.fill_std,
size=(len(valid_block_ids),) + block_shape,
dtype=kv_cache.dtype,
device=kv_cache.device,
)
else:
# Constant fill value
fill_values = torch.full(
(len(valid_block_ids),) + block_shape,
self.fill_mean,
dtype=kv_cache.dtype,
device=kv_cache.device,
)
# Batch fill operation
kv_cache[valid_block_ids] = fill_values
logger.debug(
"DecodeBenchConnector: Filled %d blocks in group %d with %s values "
"(mean=%.3f, std=%.3f)",
len(block_ids),
group_idx,
"random" if self.fill_std > 0 else "constant",
self.fill_mean,
self.fill_std,
)
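The core of the fill path is the batched index assignment above: one tensor write per layer marks every allocated block as already "prefilled". A minimal standalone sketch of that pattern with made-up cache shapes (the real layout depends on the attention backend, so treat the dimensions below as assumptions):

import torch

# Assumed toy layout: [num_blocks, 2 (K/V), num_heads, block_size, head_dim].
num_blocks, num_heads, block_size, head_dim = 8, 4, 16, 64
kv_cache = torch.zeros(num_blocks, 2, num_heads, block_size, head_dim)

block_ids = torch.tensor([1, 2, 3], dtype=torch.long)
fill_mean, fill_std = 0.015, 0.0

shape = (len(block_ids),) + kv_cache.shape[1:]
if fill_std > 0:
    # Random fill: sample each element around fill_mean.
    fill_values = torch.normal(mean=fill_mean, std=fill_std, size=shape)
else:
    # Constant fill: every element of the selected blocks becomes fill_mean.
    fill_values = torch.full(shape, fill_mean)

# One batched write; blocks 1-3 now hold non-zero dummy KV data.
kv_cache[block_ids] = fill_values
assert torch.allclose(kv_cache[block_ids], torch.tensor(fill_mean))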