[Hybrid]: Decouple Kernel Block Size from KV Page Size (#24486)

Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
This commit is contained in:
Zhiyuan Li 2025-10-09 14:43:39 +08:00 committed by GitHub
parent d17f0fbf30
commit d24cf322e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 573 additions and 55 deletions

View File

@ -241,6 +241,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_sizes=[1],
kernel_block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
@ -335,6 +336,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_sizes=[1],
kernel_block_sizes=[1],
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
@ -344,6 +346,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_sizes=[1],
kernel_block_sizes=[1],
)
reqs: list[CachedRequestState] = []

View File

@ -68,6 +68,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
kernel_block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
)
runner.initialize_attn_backend(kv_cache_config)
@ -817,42 +820,231 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
# assert we are using FlashInfer
assert attn_shape[0] == num_blocks
assert attn_shape[0] % num_blocks == 0
block_split_ratio = attn_shape[0] // num_blocks
# use small blocks for testing to avoid memory issues
test_block_size = min(2, len(blocks0), len(blocks1))
# use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba
mid_point = num_blocks // 2
# attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
# mamba uses kernel blocks from second half
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
# create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2
attn_constant_shape = attn_shape[2:]
conv_constant_shape = conv_shape[1:]
ssm_constant_shape = ssm_shape[1:]
attn_blocks_constant = torch.full(
(len(blocks0), *attn_shape[1:]), device=DEVICE, fill_value=3.33
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
)
conv_blocks_constant = torch.full(
(len(blocks1), *conv_shape[1:]), device=DEVICE, fill_value=6.66
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
)
ssm_blocks_constant = torch.full(
(len(blocks1), *ssm_shape[1:]), device=DEVICE, fill_value=9.99
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
)
# fill all attention blocks with constant
for layer in [layer_0, layer_1]:
vllm_ctx[layer].kv_cache[0][blocks0, :] = (
attn_blocks_constant.detach().clone()
)
# Fill attention blocks with constants using kv block indices
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
# fill all mamba blocks with constant
for layer in [layer_0, layer_1]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
# fill mamba blocks with constants using kernel block indices
for layer in [layer_2, layer_3, layer_4, layer_5]:
vllm_ctx[layer].kv_cache[0][0][blocks1, :] = (
conv_blocks_constant.detach().clone()
)
vllm_ctx[layer].kv_cache[0][1][blocks1, :] = (
ssm_blocks_constant.detach().clone()
)
# mamba: kv_cache[0][component][kernel_block_idx, ...]
for i, kv_block in enumerate(kv_blocks_for_mamba):
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
# verify attention and mamba contents are correct
for layer in [layer_0, layer_1]:
assert torch.equal(
vllm_ctx[layer].kv_cache[0][blocks0, :], attn_blocks_constant
)
for i, kernel_block in enumerate(kernel_blocks_for_attention):
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
expected = attn_blocks_constant[i]
# Check K and V separately
assert torch.equal(actual_kv[0], expected)
assert torch.equal(actual_kv[1], expected)
for layer in [layer_2, layer_3, layer_4, layer_5]:
assert torch.equal(
vllm_ctx[layer].kv_cache[0][0][blocks1, :], conv_blocks_constant
)
assert torch.equal(
vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant
)
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
def test_hybrid_block_table_initialization():
"""Test hybrid block table with different kernel and kvcache_manager block
sizes."""
from vllm.v1.worker.block_table import BlockTable
# Test configuration: kvcache_manager block size = 32,
# kernel block size = 16
block_size = 32
kernel_block_sizes = [16]
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512
block_table = BlockTable(
block_size=block_size,
max_num_reqs=max_num_reqs,
max_num_blocks_per_req=max_num_blocks_per_req,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0],
)
# Verify hybrid block configuration
assert block_table.use_hybrid_blocks is True
assert block_table.block_size == kernel_block_sizes[0]
assert block_table.blocks_per_kv_block == (
block_size // kernel_block_sizes[0]
) # Changed to use first element
# Test block table conversion logic
# One kvcache_manager block should map to multiple kernel blocks
kvcache_manager_blocks = [0, 1, 2]
# Verify that kvcache_manager blocks can be converted to kernel blocks
# and that block table operations work correctly.
req_index = 0
block_table.append_row(kvcache_manager_blocks, req_index)
# Get expected kernel blocks from the implementation for verification.
expected_kernel_blocks = block_table._map_to_kernel_blocks(
np.array(kvcache_manager_blocks)
)
# Verify block table state
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)
assert np.array_equal(
block_table.block_table.np[req_index, : len(expected_kernel_blocks)],
expected_kernel_blocks,
)
def test_input_batch_with_kernel_block_sizes():
"""Test InputBatch initialization with kernel_block_sizes parameter."""
max_num_reqs = 10
max_model_len = 512
max_num_batched_tokens = 512
device = torch.device(DEVICE)
pin_memory = False
vocab_size = 50272
# Test with different kernel block sizes
block_sizes = [32, 64]
kernel_block_sizes = [16, 32]
input_batch = InputBatch(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
device=device,
pin_memory=pin_memory,
vocab_size=vocab_size,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
)
# Verify that block tables were created with kernel block sizes
assert len(input_batch.block_table.block_tables) == len(block_sizes)
for i, (kv_size, kernel_size) in enumerate(zip(block_sizes, kernel_block_sizes)):
block_table = input_batch.block_table.block_tables[i]
if kv_size != kernel_size:
assert block_table.use_hybrid_blocks is True
assert block_table.block_size == kernel_size
else:
assert block_table.use_hybrid_blocks is False
assert block_table.block_size == kernel_size
def test_hybrid_cache_integration(model_runner, dist_init):
"""Test hybrid cache architecture integration with GPUModelRunner."""
# Create a new model runner with hybrid cache configuration
vllm_config = get_vllm_config()
# Configure hybrid cache with different kvcache_manager block size
vllm_config.cache_config.block_size = 32
model_config = vllm_config.model_config
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
head_size = model_config.get_head_size()
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
num_heads, head_size, 0.1
)
runner = GPUModelRunner(vllm_config, DEVICE)
# Initialize KV cache with configuration
attn_spec = FullAttentionSpec(
block_size=16, # Use kernel block size directly
num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
)
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
kv_cache_config = KVCacheConfig(
num_blocks=NUM_BLOCKS,
kv_cache_tensors=[
KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
],
kv_cache_groups=[
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
],
)
runner.kv_cache_config = kv_cache_config
# Initialize input batch with kernel block sizes
runner.input_batch = InputBatch(
max_num_reqs=runner.max_num_reqs,
max_model_len=runner.max_model_len,
max_num_batched_tokens=runner.max_num_tokens,
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
kernel_block_sizes=[16],
) # Use kernel block size
runner.initialize_attn_backend(kv_cache_config)
# Verify hybrid block table configuration
block_table = runner.input_batch.block_table.block_tables[0]
assert block_table.block_size == (
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
)
# Test request processing with hybrid blocks
req_id = "hybrid_req_0"
scheduler_output = _schedule_new_request(req_id)
# Update states should work with hybrid blocks
runner._update_states(scheduler_output)
assert _is_req_scheduled(runner, req_id)
assert _is_req_state_block_table_match(runner, req_id)

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Generic, Optional, Protocol, TypeVar
from typing import Generic, Optional, Protocol, TypeVar, Union
import torch
@ -26,6 +26,13 @@ class AttentionType:
"""Attention between dec. Q and enc. K/V for encoder-decoder."""
class MultipleOf:
base: int
def __init__(self, base: int):
self.base = base
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
@ -57,6 +64,10 @@ class AttentionBackend(ABC):
def get_metadata_cls() -> type["AttentionMetadata"]:
raise NotImplementedError
@classmethod
def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]:
return cls.get_impl_cls().get_supported_kernel_block_size()
@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@ -157,6 +168,11 @@ class AttentionImpl(ABC, Generic[T]):
) -> None:
raise NotImplementedError
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
# TODO: implement this function for all backends.
return [MultipleOf(1)]
@abstractmethod
def forward(
self,

View File

@ -365,6 +365,23 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
block_size=model_config.max_model_len,
).page_size_bytes
# Model may be marked as is_hybrid
# but mamba is skipped via config,
# return directly
if mamba_page_size == 0:
return
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: 128-byte alignment
# * Other MLA backends: 64-byte alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
else:
kernel_block_alignment_size = 16
if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
@ -381,19 +398,28 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# TODO(tdoublep): this constraint can be relaxed fairly
# easily by changing the way we layout chunks in the
# mamba2 kernels.
chunk_size = model_config.get_mamba_chunk_size()
from math import gcd
def lcm(a, b):
return a * b // gcd(a, b)
base_chunk_size = model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
# Without prefix caching, select minimum valid attention block size
# to minimize mamba state padding
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
# Calculate minimum attention block size that satisfies both:
# 1. Backend alignment requirements (kernel_block_alignment_size)
# 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
attn_block_size = kernel_block_alignment_size * cdiv(
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
# override attention block size if either (a) the
# user has not set it or (b) the user has set it

View File

@ -118,7 +118,15 @@ class CudaPlatformBase(Platform):
# TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing
if model_config is not None and model_config.use_mla:
# Note: block_size is initialized in
# HybridAttentionMambaModelConfig.verify_and_update_config
# for models with both attention and mamba,
# and doesn't need to be reinitialized here
if (
model_config is not None
and model_config.use_mla
and cache_config.block_size is not None
):
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
@ -151,18 +159,22 @@ class CudaPlatformBase(Platform):
if (
use_flashmla
and is_flashmla_dense_supported()[0]
and cache_config.block_size != 64
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")
if use_cutlass_mla and cache_config.block_size != 128:
if use_cutlass_mla and cache_config.block_size % 128 != 0:
cache_config.block_size = 128
logger.info(
"Forcing kv cache block size to 128 for CUTLASS_MLA backend."
)
if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
if (
use_flashinfer_mla
and cache_config.block_size != 32
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA backend."
@ -269,12 +281,12 @@ class CudaPlatformBase(Platform):
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
selected_backend is None
and cls.is_device_capability(100)
and block_size == 128
and block_size % 128 == 0
)
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
selected_backend is None
and cls.is_device_capability(100)
and block_size in [32, 64]
and (block_size == 32 or block_size % 64 == 0)
)
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_dense_supported()[0]
@ -298,7 +310,7 @@ class CudaPlatformBase(Platform):
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
)
if use_flashmla:
if block_size != 64:
if block_size % 64 != 0:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",

View File

@ -3,7 +3,7 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union
import numpy as np
import torch
@ -14,6 +14,7 @@ from vllm.attention.backends.abstract import (
AttentionImpl,
AttentionMetadata,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.attention.layer import Attention
@ -57,6 +58,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()

View File

@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
@ -165,6 +166,13 @@ class FlashInferBackend(AttentionBackend):
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
return [64, 128, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
# Note: Not sure for all platforms,
# but on Blackwell, only support a page size of
# 16, 32, 64
return [16, 32, 64]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()

View File

@ -10,6 +10,7 @@ import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.logger import init_logger
@ -44,6 +45,10 @@ class CutlassMLABackend(MLACommonBackend):
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
return CutlassMLAMetadataBuilder
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [128]
class SM100Workspace:
def __init__(self, initial_workspace_size):

View File

@ -6,7 +6,7 @@ from typing import ClassVar, Optional, Union
import torch
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
@ -44,6 +44,10 @@ class FlashMLABackend(MLACommonBackend):
def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [64]
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):

View File

@ -3,7 +3,7 @@
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union
import torch
@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import (
AttentionImpl,
AttentionMetadata,
AttentionType,
MultipleOf,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
@ -359,6 +360,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
def get_supported_head_sizes(cls) -> list[int]:
return [64, 128, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()

View File

@ -4,7 +4,7 @@
import ast
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union
import torch
@ -14,6 +14,7 @@ from vllm.attention.backends.abstract import (
AttentionImpl,
AttentionMetadata,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
@ -39,6 +40,10 @@ class TreeAttentionBackend(AttentionBackend):
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()

View File

@ -3,7 +3,7 @@
"""High-Performance Triton-only Attention layer."""
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import (
AttentionImpl,
AttentionMetadata,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
@ -157,6 +158,10 @@ class TritonAttentionBackend(AttentionBackend):
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
# Triton Attention supports any head size above 32

View File

@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union
import torch
@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import (
AttentionImpl,
AttentionMetadata,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
@ -80,6 +81,10 @@ class XFormersAttentionBackend(AttentionBackend):
256,
]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()

View File

@ -22,22 +22,64 @@ class BlockTable:
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
):
self.block_size = block_size
"""
Args:
block_size: Block size used for KV cache memory allocation
max_num_reqs: Maximum number of concurrent requests supported.
max_num_blocks_per_req: Maximum number of blocks per request.
max_num_batched_tokens: Maximum number of tokens in a batch.
pin_memory: Whether to pin memory for faster GPU transfers.
device: Target device for the block table.
kernel_block_size: The block_size of underlying attention kernel.
Will be the same as `block_size` if `block_size` is supported
by the attention kernel.
"""
self.max_num_reqs = max_num_reqs
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens
self.pin_memory = pin_memory
self.device = device
if kernel_block_size == block_size:
# Standard case: allocation and computation use same block size
# No block splitting needed, direct mapping
self.block_size = block_size
self.blocks_per_kv_block = 1
self.use_hybrid_blocks = False
else:
# Hybrid case: allocation block size differs from kernel block size
# Memory blocks are subdivided to match kernel requirements
# Example: 32-token memory blocks with 16-token kernel blocks
# → Each memory block corresponds to 2 kernel blocks
if block_size % kernel_block_size != 0:
raise ValueError(
f"kernel_block_size {kernel_block_size} must divide "
f"kv_manager_block_size size {block_size} evenly"
)
self.block_size = kernel_block_size
self.blocks_per_kv_block = block_size // kernel_block_size
self.use_hybrid_blocks = True
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
self.block_table = self._make_buffer(
max_num_reqs, max_num_blocks_per_req, dtype=torch.int32
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens, dtype=torch.int64
)
if self.use_hybrid_blocks:
self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
1, -1
)
else:
self._kernel_block_arange = None
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@ -53,6 +95,10 @@ class BlockTable:
) -> None:
if not block_ids:
return
if self.use_hybrid_blocks:
block_ids = self._map_to_kernel_blocks(np.array(block_ids))
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.num_blocks_per_row[row_idx] += num_blocks
@ -94,6 +140,7 @@ class BlockTable:
req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
@ -111,6 +158,7 @@ class BlockTable:
block_table_indices = (
req_indices * self.max_num_blocks_per_req + positions // self.block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(
@ -129,6 +177,31 @@ class BlockTable:
self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0)
def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
"""Convert kv_manager_block_id IDs to kernel block IDs.
Example:
# kv_manager_block_ids: 32 tokens,
# Kernel block size: 16 tokens
# blocks_per_kv_block = 2
>>> kv_manager_block_ids = np.array([0, 1, 2])
>>> Result: [0, 1, 2, 3, 4, 5]
# Each kv_manager_block_id maps to 2 kernel block id:
# kv_manager_block_id 0 → kernel block id [0, 1]
# kv_manager_block_id 1 → kernel block id [2, 3]
# kv_manager_block_id 2 → kernel block id [4, 5]
"""
if not self.use_hybrid_blocks:
return kv_manager_block_ids
kernel_block_ids = (
kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
+ self._kernel_block_arange
)
return kernel_block_ids.reshape(-1)
def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
"""Returns the device tensor of the block table."""
return self.block_table.gpu[:num_reqs]
@ -160,6 +233,7 @@ class MultiGroupBlockTable:
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
) -> None:
# Note(hc): each dcp rank only store
@ -172,6 +246,12 @@ class MultiGroupBlockTable:
# DCP might not be initialized in testing
dcp_world_size = 1
if len(kernel_block_sizes) != len(block_sizes):
raise ValueError(
f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
f"must match block_sizes length ({len(block_sizes)})"
)
self.block_tables = [
BlockTable(
block_size,
@ -183,8 +263,9 @@ class MultiGroupBlockTable:
max_num_batched_tokens,
pin_memory,
device,
kernel_block_size,
)
for block_size in block_sizes
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
]
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:

View File

@ -78,6 +78,7 @@ class InputBatch:
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[int],
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
@ -132,6 +133,7 @@ class InputBatch:
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens,
)

View File

@ -19,7 +19,7 @@ from typing_extensions import TypeAlias
import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
from vllm.attention.layer import MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size],
kernel_block_sizes=[self.cache_config.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config,
@ -4050,6 +4051,86 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
self.reorder_batch_threshold = reorder_batch_threshold_i
def _find_compatible_block_sizes(
self,
kv_manager_block_size: int,
backend_cls: type[AttentionBackend],
return_all: bool = False,
) -> list[int]:
"""
Find compatible block sizes for a backend.
Args:
kv_manager_block_size: Physical block size of KV cache
backend_cls: Attention backend class
return_all: Return all compatible sizes if True, max size if False
Returns:
Compatible block size(s) based on return_all parameter
Raises:
ValueError: If no compatible block size found
"""
supported_block_size = backend_cls.get_supported_kernel_block_size()
compatible_sizes = []
for block_size in supported_block_size:
if isinstance(block_size, int):
if kv_manager_block_size % block_size == 0:
compatible_sizes.append(block_size)
elif (
isinstance(block_size, MultipleOf)
and kv_manager_block_size % block_size.base == 0
):
compatible_sizes.append(kv_manager_block_size)
if not compatible_sizes:
raise ValueError(f"No compatible block size for {kv_manager_block_size}")
return compatible_sizes if return_all else [max(compatible_sizes)]
def _select_common_block_size(
self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
"""
Select common block size for all backends.
Args:
kv_manager_block_size: Block size of KV cache
attn_groups: List of attention groups
Returns:
Block size supported by all backends,
prioritizing cache_config.block_size
Raises:
ValueError: If no common block size found
"""
all_backend_supports = []
for attn_group in attn_groups:
compatible_sizes = self._find_compatible_block_sizes(
kv_manager_block_size, attn_group.backend, return_all=True
)
supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
all_backend_supports.append(set(supported_sizes))
common_supported_sizes = set.intersection(*all_backend_supports)
if not common_supported_sizes:
error_msg = f"No common block size for {kv_manager_block_size}. "
for i, attn_group in enumerate(attn_groups):
supported = all_backend_supports[i]
error_msg += (
f"Backend {attn_group.backend} supports: {sorted(supported)}. "
)
raise ValueError(error_msg)
if self.cache_config.block_size in common_supported_sizes:
return self.cache_config.block_size
return max(common_supported_sizes)
def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
"""
Re-initialize the input batch if the block sizes are different from
@ -4062,8 +4143,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
]
if block_sizes != [self.cache_config.block_size]:
# Generate kernel_block_sizes that matches each block_size
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
self.cache_config.block_size
]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
@ -4077,6 +4165,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
is_pooling_model=self.is_pooling_model,
@ -4128,6 +4217,46 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for attn_groups in self.attn_groups:
yield from attn_groups
def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]:
"""
Generate kernel_block_sizes that matches each block_size.
For attention backends that support virtual block splitting,
use the supported block sizes from the backend.
For other backends (like Mamba), use the same block size (no splitting).
Args:
kv_cache_config: The KV cache configuration.
Returns:
list[int]: List of kernel block sizes for each cache group.
"""
kernel_block_sizes = []
for kv_cache_group_id, kv_cache_group in enumerate(
kv_cache_config.kv_cache_groups
):
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
continue
elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual
# block splitting. Get the supported block sizes from
# all backends in the group.
attn_groups = self.attn_groups[kv_cache_group_id]
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
selected_kernel_size = self._select_common_block_size(
kv_manager_block_size, attn_groups
)
kernel_block_sizes.append(selected_kernel_size)
elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
# This is likely Mamba or other non-attention cache,
# no splitting.
kernel_block_sizes.append(kv_cache_group.kv_cache_spec.block_size)
else:
raise NotImplementedError(
f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
)
return kernel_block_sizes
def _reshape_kv_cache_tensors(
self,
kv_cache_config: KVCacheConfig,
@ -4157,16 +4286,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_manager_block_size = kv_cache_spec.block_size
kernel_size_list = self._find_compatible_block_sizes(
kv_manager_block_size, attn_backend, return_all=False
)
kernel_size = kernel_size_list[0]
num_blocks_per_kv_block = kv_manager_block_size // kernel_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kernel_num_blocks,
kernel_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype,
)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@ -4320,10 +4457,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
# Reinitialize need to after initialize_attn_backend
self.may_reinitialize_input_batch(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
if self.speculative_config and self.speculative_config.use_eagle():

View File

@ -27,6 +27,7 @@ class InputBatch:
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[int],
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@ -68,6 +69,7 @@ class InputBatch:
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
)
# Sampling-related.

View File

@ -259,6 +259,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.block_size],
kernel_block_sizes=[self.cache_config.block_size],
)
# Cached torch/numpy tensor
@ -1788,6 +1789,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
kernel_block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
)
# Verify dtype compatibility between block_table_cpu and input_batch
assert (