[Hybrid]: Decouple Kernel Block Size from KV Page Size (#24486)

Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
This commit is contained in:
Zhiyuan Li 2025-10-09 14:43:39 +08:00 committed by GitHub
parent d17f0fbf30
commit d24cf322e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 573 additions and 55 deletions

View File

@ -241,6 +241,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
vocab_size=1024, vocab_size=1024,
block_sizes=[1], block_sizes=[1],
kernel_block_sizes=[1],
) )
reqs: list[CachedRequestState] = [] reqs: list[CachedRequestState] = []
req_id_reqs = {} req_id_reqs = {}
@ -335,6 +336,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
vocab_size=1024, vocab_size=1024,
block_sizes=[1], block_sizes=[1],
kernel_block_sizes=[1],
) )
ref_input_batch: InputBatch = InputBatch( ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size, max_num_reqs=batch_size,
@ -344,6 +346,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
vocab_size=1024, vocab_size=1024,
block_sizes=[1], block_sizes=[1],
kernel_block_sizes=[1],
) )
reqs: list[CachedRequestState] = [] reqs: list[CachedRequestState] = []

View File

@ -68,6 +68,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
pin_memory=runner.pin_memory, pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(), vocab_size=runner.model_config.get_vocab_size(),
block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size], block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
kernel_block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
) )
runner.initialize_attn_backend(kv_cache_config) runner.initialize_attn_backend(kv_cache_config)
@ -817,42 +820,231 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
# assert we are using FlashInfer # assert we are using FlashInfer
assert attn_shape[0] == num_blocks assert attn_shape[0] % num_blocks == 0
block_split_ratio = attn_shape[0] // num_blocks
# use small blocks for testing to avoid memory issues
test_block_size = min(2, len(blocks0), len(blocks1))
# use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba
mid_point = num_blocks // 2
# attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
# mamba uses kernel blocks from second half
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
# create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2
attn_constant_shape = attn_shape[2:]
conv_constant_shape = conv_shape[1:]
ssm_constant_shape = ssm_shape[1:]
attn_blocks_constant = torch.full( attn_blocks_constant = torch.full(
(len(blocks0), *attn_shape[1:]), device=DEVICE, fill_value=3.33 (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
) )
conv_blocks_constant = torch.full( conv_blocks_constant = torch.full(
(len(blocks1), *conv_shape[1:]), device=DEVICE, fill_value=6.66 (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
) )
ssm_blocks_constant = torch.full( ssm_blocks_constant = torch.full(
(len(blocks1), *ssm_shape[1:]), device=DEVICE, fill_value=9.99 (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
) )
# fill all attention blocks with constant # Fill attention blocks with constants using kv block indices
for layer in [layer_0, layer_1]: kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
vllm_ctx[layer].kv_cache[0][blocks0, :] = (
attn_blocks_constant.detach().clone()
)
# fill all mamba blocks with constant for layer in [layer_0, layer_1]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
# fill mamba blocks with constants using kernel block indices
for layer in [layer_2, layer_3, layer_4, layer_5]: for layer in [layer_2, layer_3, layer_4, layer_5]:
vllm_ctx[layer].kv_cache[0][0][blocks1, :] = ( # mamba: kv_cache[0][component][kernel_block_idx, ...]
conv_blocks_constant.detach().clone() for i, kv_block in enumerate(kv_blocks_for_mamba):
) vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][blocks1, :] = ( vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
ssm_blocks_constant.detach().clone()
)
# verify attention and mamba contents are correct # verify attention and mamba contents are correct
for layer in [layer_0, layer_1]: for layer in [layer_0, layer_1]:
assert torch.equal( for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][blocks0, :], attn_blocks_constant actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
) expected = attn_blocks_constant[i]
# Check K and V separately
assert torch.equal(actual_kv[0], expected)
assert torch.equal(actual_kv[1], expected)
for layer in [layer_2, layer_3, layer_4, layer_5]: for layer in [layer_2, layer_3, layer_4, layer_5]:
assert torch.equal( for i, kv_block in enumerate(kv_blocks_for_mamba):
vllm_ctx[layer].kv_cache[0][0][blocks1, :], conv_blocks_constant actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
) actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
assert torch.equal( expected_conv = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant expected_ssm = ssm_blocks_constant[i]
)
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
def test_hybrid_block_table_initialization():
"""Test hybrid block table with different kernel and kvcache_manager block
sizes."""
from vllm.v1.worker.block_table import BlockTable
# Test configuration: kvcache_manager block size = 32,
# kernel block size = 16
block_size = 32
kernel_block_sizes = [16]
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512
block_table = BlockTable(
block_size=block_size,
max_num_reqs=max_num_reqs,
max_num_blocks_per_req=max_num_blocks_per_req,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0],
)
# Verify hybrid block configuration
assert block_table.use_hybrid_blocks is True
assert block_table.block_size == kernel_block_sizes[0]
assert block_table.blocks_per_kv_block == (
block_size // kernel_block_sizes[0]
) # Changed to use first element
# Test block table conversion logic
# One kvcache_manager block should map to multiple kernel blocks
kvcache_manager_blocks = [0, 1, 2]
# Verify that kvcache_manager blocks can be converted to kernel blocks
# and that block table operations work correctly.
req_index = 0
block_table.append_row(kvcache_manager_blocks, req_index)
# Get expected kernel blocks from the implementation for verification.
expected_kernel_blocks = block_table._map_to_kernel_blocks(
np.array(kvcache_manager_blocks)
)
# Verify block table state
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)
assert np.array_equal(
block_table.block_table.np[req_index, : len(expected_kernel_blocks)],
expected_kernel_blocks,
)
def test_input_batch_with_kernel_block_sizes():
"""Test InputBatch initialization with kernel_block_sizes parameter."""
max_num_reqs = 10
max_model_len = 512
max_num_batched_tokens = 512
device = torch.device(DEVICE)
pin_memory = False
vocab_size = 50272
# Test with different kernel block sizes
block_sizes = [32, 64]
kernel_block_sizes = [16, 32]
input_batch = InputBatch(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
device=device,
pin_memory=pin_memory,
vocab_size=vocab_size,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
)
# Verify that block tables were created with kernel block sizes
assert len(input_batch.block_table.block_tables) == len(block_sizes)
for i, (kv_size, kernel_size) in enumerate(zip(block_sizes, kernel_block_sizes)):
block_table = input_batch.block_table.block_tables[i]
if kv_size != kernel_size:
assert block_table.use_hybrid_blocks is True
assert block_table.block_size == kernel_size
else:
assert block_table.use_hybrid_blocks is False
assert block_table.block_size == kernel_size
def test_hybrid_cache_integration(model_runner, dist_init):
"""Test hybrid cache architecture integration with GPUModelRunner."""
# Create a new model runner with hybrid cache configuration
vllm_config = get_vllm_config()
# Configure hybrid cache with different kvcache_manager block size
vllm_config.cache_config.block_size = 32
model_config = vllm_config.model_config
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
head_size = model_config.get_head_size()
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
num_heads, head_size, 0.1
)
runner = GPUModelRunner(vllm_config, DEVICE)
# Initialize KV cache with configuration
attn_spec = FullAttentionSpec(
block_size=16, # Use kernel block size directly
num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
)
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
kv_cache_config = KVCacheConfig(
num_blocks=NUM_BLOCKS,
kv_cache_tensors=[
KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
],
kv_cache_groups=[
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
],
)
runner.kv_cache_config = kv_cache_config
# Initialize input batch with kernel block sizes
runner.input_batch = InputBatch(
max_num_reqs=runner.max_num_reqs,
max_model_len=runner.max_model_len,
max_num_batched_tokens=runner.max_num_tokens,
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
kernel_block_sizes=[16],
) # Use kernel block size
runner.initialize_attn_backend(kv_cache_config)
# Verify hybrid block table configuration
block_table = runner.input_batch.block_table.block_tables[0]
assert block_table.block_size == (
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
)
# Test request processing with hybrid blocks
req_id = "hybrid_req_0"
scheduler_output = _schedule_new_request(req_id)
# Update states should work with hybrid blocks
runner._update_states(scheduler_output)
assert _is_req_scheduled(runner, req_id)
assert _is_req_state_block_table_match(runner, req_id)

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, Optional, Protocol, TypeVar from typing import Generic, Optional, Protocol, TypeVar, Union
import torch import torch
@ -26,6 +26,13 @@ class AttentionType:
"""Attention between dec. Q and enc. K/V for encoder-decoder.""" """Attention between dec. Q and enc. K/V for encoder-decoder."""
class MultipleOf:
base: int
def __init__(self, base: int):
self.base = base
class AttentionBackend(ABC): class AttentionBackend(ABC):
"""Abstract class for attention backends.""" """Abstract class for attention backends."""
@ -57,6 +64,10 @@ class AttentionBackend(ABC):
def get_metadata_cls() -> type["AttentionMetadata"]: def get_metadata_cls() -> type["AttentionMetadata"]:
raise NotImplementedError raise NotImplementedError
@classmethod
def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]:
return cls.get_impl_cls().get_supported_kernel_block_size()
@classmethod @classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs) return cls.get_metadata_cls()(*args, **kwargs)
@ -157,6 +168,11 @@ class AttentionImpl(ABC, Generic[T]):
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
# TODO: implement this function for all backends.
return [MultipleOf(1)]
@abstractmethod @abstractmethod
def forward( def forward(
self, self,

View File

@ -365,6 +365,23 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
block_size=model_config.max_model_len, block_size=model_config.max_model_len,
).page_size_bytes ).page_size_bytes
# Model may be marked as is_hybrid
# but mamba is skipped via config,
# return directly
if mamba_page_size == 0:
return
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: 128-byte alignment
# * Other MLA backends: 64-byte alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
else:
kernel_block_alignment_size = 16
if cache_config.enable_prefix_caching: if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to # With prefix caching, select attention block size to
# optimize for mamba kernel performance # optimize for mamba kernel performance
@ -381,19 +398,28 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# TODO(tdoublep): this constraint can be relaxed fairly # TODO(tdoublep): this constraint can be relaxed fairly
# easily by changing the way we layout chunks in the # easily by changing the way we layout chunks in the
# mamba2 kernels. # mamba2 kernels.
chunk_size = model_config.get_mamba_chunk_size()
from math import gcd
def lcm(a, b):
return a * b // gcd(a, b)
base_chunk_size = model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size cache_config.mamba_block_size = attn_block_size
else: else:
# Without prefix caching, select minimum valid attention block size # Without prefix caching, select minimum valid attention block size
# to minimize mamba state padding # to minimize mamba state padding
# some attention backends (e.g. FA) only support setting # Calculate minimum attention block size that satisfies both:
# block size to multiple of 16, so let's suggest a value # 1. Backend alignment requirements (kernel_block_alignment_size)
# that would work (note: FA is currently not compatible # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
# with mamba layers, use FlashInfer instead). attn_block_size = kernel_block_alignment_size * cdiv(
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token) mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
# override attention block size if either (a) the # override attention block size if either (a) the
# user has not set it or (b) the user has set it # user has not set it or (b) the user has set it

View File

@ -118,7 +118,15 @@ class CudaPlatformBase(Platform):
# TODO(lucas): handle this more gracefully # TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing # Note: model_config may be None during testing
if model_config is not None and model_config.use_mla: # Note: block_size is initialized in
# HybridAttentionMambaModelConfig.verify_and_update_config
# for models with both attention and mamba,
# and doesn't need to be reinitialized here
if (
model_config is not None
and model_config.use_mla
and cache_config.block_size is not None
):
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs, # then we default to FlashMLA backend for non-blackwell GPUs,
@ -151,18 +159,22 @@ class CudaPlatformBase(Platform):
if ( if (
use_flashmla use_flashmla
and is_flashmla_dense_supported()[0] and is_flashmla_dense_supported()[0]
and cache_config.block_size != 64 and cache_config.block_size % 64 != 0
): ):
cache_config.block_size = 64 cache_config.block_size = 64
logger.info("Forcing kv cache block size to 64 for FlashMLA backend.") logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")
if use_cutlass_mla and cache_config.block_size != 128: if use_cutlass_mla and cache_config.block_size % 128 != 0:
cache_config.block_size = 128 cache_config.block_size = 128
logger.info( logger.info(
"Forcing kv cache block size to 128 for CUTLASS_MLA backend." "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
) )
if use_flashinfer_mla and cache_config.block_size not in [32, 64]: if (
use_flashinfer_mla
and cache_config.block_size != 32
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64 cache_config.block_size = 64
logger.info( logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA backend." "Forcing kv cache block size to 64 for FlashInferMLA backend."
@ -269,12 +281,12 @@ class CudaPlatformBase(Platform):
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
selected_backend is None selected_backend is None
and cls.is_device_capability(100) and cls.is_device_capability(100)
and block_size == 128 and block_size % 128 == 0
) )
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
selected_backend is None selected_backend is None
and cls.is_device_capability(100) and cls.is_device_capability(100)
and block_size in [32, 64] and (block_size == 32 or block_size % 64 == 0)
) )
use_flashmla = selected_backend == _Backend.FLASHMLA or ( use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_dense_supported()[0] selected_backend is None and is_flashmla_dense_supported()[0]
@ -298,7 +310,7 @@ class CudaPlatformBase(Platform):
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
) )
if use_flashmla: if use_flashmla:
if block_size != 64: if block_size % 64 != 0:
logger.warning( logger.warning(
"FlashMLA backend is not supported for block size %d" "FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).", " (currently only supports block size 64).",

View File

@ -3,7 +3,7 @@
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
@ -14,6 +14,7 @@ from vllm.attention.backends.abstract import (
AttentionImpl, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
MultipleOf,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
@ -57,6 +58,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256] return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes() supported_head_sizes = cls.get_supported_head_sizes()

View File

@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionType, AttentionType,
MultipleOf,
) )
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
@ -165,6 +166,13 @@ class FlashInferBackend(AttentionBackend):
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
return [64, 128, 256] return [64, 128, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
# Note: Not sure for all platforms,
# but on Blackwell, only support a page size of
# 16, 32, 64
return [16, 32, 64]
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes() supported_head_sizes = cls.get_supported_head_sizes()

View File

@ -10,6 +10,7 @@ import vllm._custom_ops as ops
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
@ -44,6 +45,10 @@ class CutlassMLABackend(MLACommonBackend):
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
return CutlassMLAMetadataBuilder return CutlassMLAMetadataBuilder
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [128]
class SM100Workspace: class SM100Workspace:
def __init__(self, initial_workspace_size): def __init__(self, initial_workspace_size):

View File

@ -6,7 +6,7 @@ from typing import ClassVar, Optional, Union
import torch import torch
from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
from vllm.attention.ops.flashmla import ( from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache, flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
@ -44,6 +44,10 @@ class FlashMLABackend(MLACommonBackend):
def get_impl_cls() -> type["FlashMLAImpl"]: def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl return FlashMLAImpl
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [64]
@dataclass @dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata): class FlashMLADecodeMetadata(MLACommonDecodeMetadata):

View File

@ -3,7 +3,7 @@
"""Attention layer with AiterFlashAttention.""" """Attention layer with AiterFlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, Union
import torch import torch
@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import (
AttentionImpl, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
MultipleOf,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
@ -359,6 +360,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [64, 128, 256] return [64, 128, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes() supported_head_sizes = cls.get_supported_head_sizes()

View File

@ -4,7 +4,7 @@
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, Union
import torch import torch
@ -14,6 +14,7 @@ from vllm.attention.backends.abstract import (
AttentionImpl, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
MultipleOf,
) )
from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -39,6 +40,10 @@ class TreeAttentionBackend(AttentionBackend):
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256] return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes() supported_head_sizes = cls.get_supported_head_sizes()

View File

@ -3,7 +3,7 @@
"""High-Performance Triton-only Attention layer.""" """High-Performance Triton-only Attention layer."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional from typing import ClassVar, Optional, Union
import torch import torch
@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import (
AttentionImpl, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
MultipleOf,
) )
from vllm.attention.ops.triton_reshape_and_cache_flash import ( from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash, triton_reshape_and_cache_flash,
@ -157,6 +158,10 @@ class TritonAttentionBackend(AttentionBackend):
def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32] return [torch.float16, torch.bfloat16, torch.float32]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def validate_head_size(cls, head_size: int) -> None:
# Triton Attention supports any head size above 32 # Triton Attention supports any head size above 32

View File

@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention.""" """Attention layer with XFormersAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, Union
import torch import torch
@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import (
AttentionImpl, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
MultipleOf,
) )
from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -80,6 +81,10 @@ class XFormersAttentionBackend(AttentionBackend):
256, 256,
] ]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes() supported_head_sizes = cls.get_supported_head_sizes()

View File

@ -22,22 +22,64 @@ class BlockTable:
max_num_batched_tokens: int, max_num_batched_tokens: int,
pin_memory: bool, pin_memory: bool,
device: torch.device, device: torch.device,
kernel_block_size: int,
): ):
self.block_size = block_size """
Args:
block_size: Block size used for KV cache memory allocation
max_num_reqs: Maximum number of concurrent requests supported.
max_num_blocks_per_req: Maximum number of blocks per request.
max_num_batched_tokens: Maximum number of tokens in a batch.
pin_memory: Whether to pin memory for faster GPU transfers.
device: Target device for the block table.
kernel_block_size: The block_size of underlying attention kernel.
Will be the same as `block_size` if `block_size` is supported
by the attention kernel.
"""
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
self.pin_memory = pin_memory self.pin_memory = pin_memory
self.device = device self.device = device
if kernel_block_size == block_size:
# Standard case: allocation and computation use same block size
# No block splitting needed, direct mapping
self.block_size = block_size
self.blocks_per_kv_block = 1
self.use_hybrid_blocks = False
else:
# Hybrid case: allocation block size differs from kernel block size
# Memory blocks are subdivided to match kernel requirements
# Example: 32-token memory blocks with 16-token kernel blocks
# → Each memory block corresponds to 2 kernel blocks
if block_size % kernel_block_size != 0:
raise ValueError(
f"kernel_block_size {kernel_block_size} must divide "
f"kv_manager_block_size size {block_size} evenly"
)
self.block_size = kernel_block_size
self.blocks_per_kv_block = block_size // kernel_block_size
self.use_hybrid_blocks = True
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
self.block_table = self._make_buffer( self.block_table = self._make_buffer(
max_num_reqs, max_num_blocks_per_req, dtype=torch.int32 self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
) )
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping = self._make_buffer( self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens, dtype=torch.int64 self.max_num_batched_tokens, dtype=torch.int64
) )
if self.use_hybrid_blocks:
self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
1, -1
)
else:
self._kernel_block_arange = None
try: try:
self.dcp_world_size = get_dcp_group().world_size self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group self.dcp_rank = get_dcp_group().rank_in_group
@ -53,6 +95,10 @@ class BlockTable:
) -> None: ) -> None:
if not block_ids: if not block_ids:
return return
if self.use_hybrid_blocks:
block_ids = self._map_to_kernel_blocks(np.array(block_ids))
num_blocks = len(block_ids) num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx] start = self.num_blocks_per_row[row_idx]
self.num_blocks_per_row[row_idx] += num_blocks self.num_blocks_per_row[row_idx] += num_blocks
@ -94,6 +140,7 @@ class BlockTable:
req_indices * self.max_num_blocks_per_req req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size + positions // virtual_block_size
) )
block_numbers = self.block_table.np.ravel()[block_table_indices] block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local # Use virtual_block_size for mask calculation, which marks local
# tokens. # tokens.
@ -111,6 +158,7 @@ class BlockTable:
block_table_indices = ( block_table_indices = (
req_indices * self.max_num_blocks_per_req + positions // self.block_size req_indices * self.max_num_blocks_per_req + positions // self.block_size
) )
block_numbers = self.block_table.np.ravel()[block_table_indices] block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size block_offsets = positions % self.block_size
np.add( np.add(
@ -129,6 +177,31 @@ class BlockTable:
self.block_table.gpu.fill_(0) self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0) self.block_table.cpu.fill_(0)
def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
"""Convert kv_manager_block_id IDs to kernel block IDs.
Example:
# kv_manager_block_ids: 32 tokens,
# Kernel block size: 16 tokens
# blocks_per_kv_block = 2
>>> kv_manager_block_ids = np.array([0, 1, 2])
>>> Result: [0, 1, 2, 3, 4, 5]
# Each kv_manager_block_id maps to 2 kernel block id:
# kv_manager_block_id 0 → kernel block id [0, 1]
# kv_manager_block_id 1 → kernel block id [2, 3]
# kv_manager_block_id 2 → kernel block id [4, 5]
"""
if not self.use_hybrid_blocks:
return kv_manager_block_ids
kernel_block_ids = (
kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
+ self._kernel_block_arange
)
return kernel_block_ids.reshape(-1)
def get_device_tensor(self, num_reqs: int) -> torch.Tensor: def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
"""Returns the device tensor of the block table.""" """Returns the device tensor of the block table."""
return self.block_table.gpu[:num_reqs] return self.block_table.gpu[:num_reqs]
@ -160,6 +233,7 @@ class MultiGroupBlockTable:
pin_memory: bool, pin_memory: bool,
device: torch.device, device: torch.device,
block_sizes: list[int], block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0, num_speculative_tokens: int = 0,
) -> None: ) -> None:
# Note(hc): each dcp rank only store # Note(hc): each dcp rank only store
@ -172,6 +246,12 @@ class MultiGroupBlockTable:
# DCP might not be initialized in testing # DCP might not be initialized in testing
dcp_world_size = 1 dcp_world_size = 1
if len(kernel_block_sizes) != len(block_sizes):
raise ValueError(
f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
f"must match block_sizes length ({len(block_sizes)})"
)
self.block_tables = [ self.block_tables = [
BlockTable( BlockTable(
block_size, block_size,
@ -183,8 +263,9 @@ class MultiGroupBlockTable:
max_num_batched_tokens, max_num_batched_tokens,
pin_memory, pin_memory,
device, device,
kernel_block_size,
) )
for block_size in block_sizes for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
] ]
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:

View File

@ -78,6 +78,7 @@ class InputBatch:
pin_memory: bool, pin_memory: bool,
vocab_size: int, vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[int],
logitsprocs: Optional[LogitsProcessors] = None, logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False, is_spec_decode: bool = False,
is_pooling_model: bool = False, is_pooling_model: bool = False,
@ -132,6 +133,7 @@ class InputBatch:
pin_memory=pin_memory, pin_memory=pin_memory,
device=device, device=device,
block_sizes=block_sizes, block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens, num_speculative_tokens=num_speculative_tokens,
) )

View File

@ -19,7 +19,7 @@ from typing_extensions import TypeAlias
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import Attention, AttentionType from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
from vllm.attention.layer import MLAAttention from vllm.attention.layer import MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(), vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size], block_sizes=[self.cache_config.block_size],
kernel_block_sizes=[self.cache_config.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config), is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs( logitsprocs=build_logitsprocs(
self.vllm_config, self.vllm_config,
@ -4050,6 +4051,86 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else: else:
self.reorder_batch_threshold = reorder_batch_threshold_i self.reorder_batch_threshold = reorder_batch_threshold_i
def _find_compatible_block_sizes(
self,
kv_manager_block_size: int,
backend_cls: type[AttentionBackend],
return_all: bool = False,
) -> list[int]:
"""
Find compatible block sizes for a backend.
Args:
kv_manager_block_size: Physical block size of KV cache
backend_cls: Attention backend class
return_all: Return all compatible sizes if True, max size if False
Returns:
Compatible block size(s) based on return_all parameter
Raises:
ValueError: If no compatible block size found
"""
supported_block_size = backend_cls.get_supported_kernel_block_size()
compatible_sizes = []
for block_size in supported_block_size:
if isinstance(block_size, int):
if kv_manager_block_size % block_size == 0:
compatible_sizes.append(block_size)
elif (
isinstance(block_size, MultipleOf)
and kv_manager_block_size % block_size.base == 0
):
compatible_sizes.append(kv_manager_block_size)
if not compatible_sizes:
raise ValueError(f"No compatible block size for {kv_manager_block_size}")
return compatible_sizes if return_all else [max(compatible_sizes)]
def _select_common_block_size(
self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
"""
Select common block size for all backends.
Args:
kv_manager_block_size: Block size of KV cache
attn_groups: List of attention groups
Returns:
Block size supported by all backends,
prioritizing cache_config.block_size
Raises:
ValueError: If no common block size found
"""
all_backend_supports = []
for attn_group in attn_groups:
compatible_sizes = self._find_compatible_block_sizes(
kv_manager_block_size, attn_group.backend, return_all=True
)
supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
all_backend_supports.append(set(supported_sizes))
common_supported_sizes = set.intersection(*all_backend_supports)
if not common_supported_sizes:
error_msg = f"No common block size for {kv_manager_block_size}. "
for i, attn_group in enumerate(attn_groups):
supported = all_backend_supports[i]
error_msg += (
f"Backend {attn_group.backend} supports: {sorted(supported)}. "
)
raise ValueError(error_msg)
if self.cache_config.block_size in common_supported_sizes:
return self.cache_config.block_size
return max(common_supported_sizes)
def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
""" """
Re-initialize the input batch if the block sizes are different from Re-initialize the input batch if the block sizes are different from
@ -4062,8 +4143,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_sizes = [ block_sizes = [
kv_cache_group.kv_cache_spec.block_size kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups for kv_cache_group in kv_cache_config.kv_cache_groups
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
] ]
if block_sizes != [self.cache_config.block_size]:
# Generate kernel_block_sizes that matches each block_size
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
self.cache_config.block_size
]:
assert self.cache_config.cpu_offload_gb == 0, ( assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight " "Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
@ -4077,6 +4165,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(), vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes, block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config), is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs, logitsprocs=self.input_batch.logitsprocs,
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
@ -4128,6 +4217,46 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for attn_groups in self.attn_groups: for attn_groups in self.attn_groups:
yield from attn_groups yield from attn_groups
def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]:
"""
Generate kernel_block_sizes that matches each block_size.
For attention backends that support virtual block splitting,
use the supported block sizes from the backend.
For other backends (like Mamba), use the same block size (no splitting).
Args:
kv_cache_config: The KV cache configuration.
Returns:
list[int]: List of kernel block sizes for each cache group.
"""
kernel_block_sizes = []
for kv_cache_group_id, kv_cache_group in enumerate(
kv_cache_config.kv_cache_groups
):
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
continue
elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual
# block splitting. Get the supported block sizes from
# all backends in the group.
attn_groups = self.attn_groups[kv_cache_group_id]
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
selected_kernel_size = self._select_common_block_size(
kv_manager_block_size, attn_groups
)
kernel_block_sizes.append(selected_kernel_size)
elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
# This is likely Mamba or other non-attention cache,
# no splitting.
kernel_block_sizes.append(kv_cache_group.kv_cache_spec.block_size)
else:
raise NotImplementedError(
f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
)
return kernel_block_sizes
def _reshape_kv_cache_tensors( def _reshape_kv_cache_tensors(
self, self,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
@ -4157,16 +4286,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec): if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True has_attn = True
kv_manager_block_size = kv_cache_spec.block_size
kernel_size_list = self._find_compatible_block_sizes(
kv_manager_block_size, attn_backend, return_all=False
)
kernel_size = kernel_size_list[0]
num_blocks_per_kv_block = kv_manager_block_size // kernel_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
kv_cache_shape = attn_backend.get_kv_cache_shape( kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kernel_num_blocks,
kv_cache_spec.block_size, kernel_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size, kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype, cache_dtype_str=self.cache_config.cache_dtype,
) )
dtype = kv_cache_spec.dtype dtype = kv_cache_spec.dtype
try: try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501
assert len(kv_cache_stride_order) == len(kv_cache_shape) assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError): except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape))) kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@ -4320,10 +4457,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
""" """
kv_cache_config = deepcopy(kv_cache_config) kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.may_add_encoder_only_layers_to_kv_cache_config() self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config) self.initialize_attn_backend(kv_cache_config)
# Reinitialize need to after initialize_attn_backend
self.may_reinitialize_input_batch(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle():

View File

@ -27,6 +27,7 @@ class InputBatch:
pin_memory: bool, pin_memory: bool,
vocab_size: int, vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[int],
): ):
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
@ -68,6 +69,7 @@ class InputBatch:
pin_memory=pin_memory, pin_memory=pin_memory,
device=device, device=device,
block_sizes=block_sizes, block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
) )
# Sampling-related. # Sampling-related.

View File

@ -259,6 +259,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(), vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.block_size], block_sizes=[self.block_size],
kernel_block_sizes=[self.cache_config.block_size],
) )
# Cached torch/numpy tensor # Cached torch/numpy tensor
@ -1788,6 +1789,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_sizes=[ block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
], ],
kernel_block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
) )
# Verify dtype compatibility between block_table_cpu and input_batch # Verify dtype compatibility between block_table_cpu and input_batch
assert ( assert (