[Hybrid]: Decouple Kernel Block Size from KV Page Size (#24486)
Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>

This commit is contained in:
parent d17f0fbf30
commit d24cf322e1
@@ -241,6 +241,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
        pin_memory=is_pin_memory_available(),
        vocab_size=1024,
        block_sizes=[1],
        kernel_block_sizes=[1],
    )
    reqs: list[CachedRequestState] = []
    req_id_reqs = {}
@@ -335,6 +336,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
        pin_memory=is_pin_memory_available(),
        vocab_size=1024,
        block_sizes=[1],
        kernel_block_sizes=[1],
    )
    ref_input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
@@ -344,6 +346,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
        pin_memory=is_pin_memory_available(),
        vocab_size=1024,
        block_sizes=[1],
        kernel_block_sizes=[1],
    )

    reqs: list[CachedRequestState] = []
@@ -68,6 +68,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
        pin_memory=runner.pin_memory,
        vocab_size=runner.model_config.get_vocab_size(),
        block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
        kernel_block_sizes=[
            kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
        ],
    )
    runner.initialize_attn_backend(kv_cache_config)
@@ -817,42 +820,231 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
    ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape

    # assert we are using FlashInfer
    assert attn_shape[0] == num_blocks
    assert attn_shape[0] % num_blocks == 0
    block_split_ratio = attn_shape[0] // num_blocks

    # use small blocks for testing to avoid memory issues
    test_block_size = min(2, len(blocks0), len(blocks1))

    # use non-overlapping blocks to avoid data contamination
    # Split kernel blocks: first half for attention, second half for mamba
    mid_point = num_blocks // 2

    # attention uses kernel blocks from first half (mapped to logical blocks)
    kv_blocks_for_attention = np.array([0, 1])[:test_block_size]

    # mamba uses kernel blocks from second half
    kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]

    # create small constant tensors for testing with corrected shapes
    # attention: [block_size, ...] starting from dimension 2
    attn_constant_shape = attn_shape[2:]
    conv_constant_shape = conv_shape[1:]
    ssm_constant_shape = ssm_shape[1:]

    attn_blocks_constant = torch.full(
        (len(blocks0), *attn_shape[1:]), device=DEVICE, fill_value=3.33
        (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
    )
    conv_blocks_constant = torch.full(
        (len(blocks1), *conv_shape[1:]), device=DEVICE, fill_value=6.66
        (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
    )
    ssm_blocks_constant = torch.full(
        (len(blocks1), *ssm_shape[1:]), device=DEVICE, fill_value=9.99
        (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
    )

    # fill all attention blocks with constant
    for layer in [layer_0, layer_1]:
        vllm_ctx[layer].kv_cache[0][blocks0, :] = (
            attn_blocks_constant.detach().clone()
        )
    # Fill attention blocks with constants using kv block indices
    kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio

    # fill all mamba blocks with constant
    for layer in [layer_0, layer_1]:
        # attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
        for i, kernel_block in enumerate(kernel_blocks_for_attention):
            vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]

    # fill mamba blocks with constants using kernel block indices
    for layer in [layer_2, layer_3, layer_4, layer_5]:
        vllm_ctx[layer].kv_cache[0][0][blocks1, :] = (
            conv_blocks_constant.detach().clone()
        )
        vllm_ctx[layer].kv_cache[0][1][blocks1, :] = (
            ssm_blocks_constant.detach().clone()
        )
        # mamba: kv_cache[0][component][kernel_block_idx, ...]
        for i, kv_block in enumerate(kv_blocks_for_mamba):
            vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
            vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]

    # verify attention and mamba contents are correct
    for layer in [layer_0, layer_1]:
        assert torch.equal(
            vllm_ctx[layer].kv_cache[0][blocks0, :], attn_blocks_constant
        )
        for i, kernel_block in enumerate(kernel_blocks_for_attention):
            actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
            expected = attn_blocks_constant[i]

            # Check K and V separately
            assert torch.equal(actual_kv[0], expected)
            assert torch.equal(actual_kv[1], expected)

    for layer in [layer_2, layer_3, layer_4, layer_5]:
        assert torch.equal(
            vllm_ctx[layer].kv_cache[0][0][blocks1, :], conv_blocks_constant
        )
        assert torch.equal(
            vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant
        )
        for i, kv_block in enumerate(kv_blocks_for_mamba):
            actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
            actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
            expected_conv = conv_blocks_constant[i]
            expected_ssm = ssm_blocks_constant[i]

            assert torch.equal(actual_conv, expected_conv)
            assert torch.equal(actual_ssm, expected_ssm)

    for layer in [layer_2, layer_3, layer_4, layer_5]:
        for i, kv_block in enumerate(kv_blocks_for_mamba):
            actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
            actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
            expected_conv = conv_blocks_constant[i]
            expected_ssm = ssm_blocks_constant[i]
            assert torch.equal(actual_conv, expected_conv)
            assert torch.equal(actual_ssm, expected_ssm)


def test_hybrid_block_table_initialization():
    """Test hybrid block table with different kernel and kvcache_manager block
    sizes."""
    from vllm.v1.worker.block_table import BlockTable

    # Test configuration: kvcache_manager block size = 32,
    # kernel block size = 16
    block_size = 32
    kernel_block_sizes = [16]
    max_num_reqs = 10
    max_num_blocks_per_req = 20
    max_num_batched_tokens = 512

    block_table = BlockTable(
        block_size=block_size,
        max_num_reqs=max_num_reqs,
        max_num_blocks_per_req=max_num_blocks_per_req,
        max_num_batched_tokens=max_num_batched_tokens,
        pin_memory=False,
        device=torch.device(DEVICE),
        kernel_block_size=kernel_block_sizes[0],
    )

    # Verify hybrid block configuration
    assert block_table.use_hybrid_blocks is True
    assert block_table.block_size == kernel_block_sizes[0]
    assert block_table.blocks_per_kv_block == (
        block_size // kernel_block_sizes[0]
    )  # Changed to use first element

    # Test block table conversion logic
    # One kvcache_manager block should map to multiple kernel blocks
    kvcache_manager_blocks = [0, 1, 2]

    # Verify that kvcache_manager blocks can be converted to kernel blocks
    # and that block table operations work correctly.
    req_index = 0
    block_table.append_row(kvcache_manager_blocks, req_index)
    # Get expected kernel blocks from the implementation for verification.
    expected_kernel_blocks = block_table._map_to_kernel_blocks(
        np.array(kvcache_manager_blocks)
    )
    # Verify block table state
    assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)
    assert np.array_equal(
        block_table.block_table.np[req_index, : len(expected_kernel_blocks)],
        expected_kernel_blocks,
    )


def test_input_batch_with_kernel_block_sizes():
    """Test InputBatch initialization with kernel_block_sizes parameter."""
    max_num_reqs = 10
    max_model_len = 512
    max_num_batched_tokens = 512
    device = torch.device(DEVICE)
    pin_memory = False
    vocab_size = 50272

    # Test with different kernel block sizes
    block_sizes = [32, 64]
    kernel_block_sizes = [16, 32]

    input_batch = InputBatch(
        max_num_reqs=max_num_reqs,
        max_model_len=max_model_len,
        max_num_batched_tokens=max_num_batched_tokens,
        device=device,
        pin_memory=pin_memory,
        vocab_size=vocab_size,
        block_sizes=block_sizes,
        kernel_block_sizes=kernel_block_sizes,
    )

    # Verify that block tables were created with kernel block sizes
    assert len(input_batch.block_table.block_tables) == len(block_sizes)

    for i, (kv_size, kernel_size) in enumerate(zip(block_sizes, kernel_block_sizes)):
        block_table = input_batch.block_table.block_tables[i]
        if kv_size != kernel_size:
            assert block_table.use_hybrid_blocks is True
            assert block_table.block_size == kernel_size
        else:
            assert block_table.use_hybrid_blocks is False
            assert block_table.block_size == kernel_size


def test_hybrid_cache_integration(model_runner, dist_init):
    """Test hybrid cache architecture integration with GPUModelRunner."""
    # Create a new model runner with hybrid cache configuration
    vllm_config = get_vllm_config()

    # Configure hybrid cache with different kvcache_manager block size
    vllm_config.cache_config.block_size = 32

    model_config = vllm_config.model_config
    num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
    head_size = model_config.get_head_size()
    vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
        num_heads, head_size, 0.1
    )

    runner = GPUModelRunner(vllm_config, DEVICE)

    # Initialize KV cache with configuration
    attn_spec = FullAttentionSpec(
        block_size=16,  # Use kernel block size directly
        num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config),
        head_size=runner.model_config.get_head_size(),
        dtype=runner.kv_cache_dtype,
    )
    tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
    kv_cache_config = KVCacheConfig(
        num_blocks=NUM_BLOCKS,
        kv_cache_tensors=[
            KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
        ],
        kv_cache_groups=[
            KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
        ],
    )
    runner.kv_cache_config = kv_cache_config

    # Initialize input batch with kernel block sizes
    runner.input_batch = InputBatch(
        max_num_reqs=runner.max_num_reqs,
        max_model_len=runner.max_model_len,
        max_num_batched_tokens=runner.max_num_tokens,
        device=runner.device,
        pin_memory=runner.pin_memory,
        vocab_size=runner.model_config.get_vocab_size(),
        block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
        kernel_block_sizes=[16],
    )  # Use kernel block size

    runner.initialize_attn_backend(kv_cache_config)

    # Verify hybrid block table configuration
    block_table = runner.input_batch.block_table.block_tables[0]
    assert block_table.block_size == (
        kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
    )

    # Test request processing with hybrid blocks
    req_id = "hybrid_req_0"
    scheduler_output = _schedule_new_request(req_id)

    # Update states should work with hybrid blocks
    runner._update_states(scheduler_output)
    assert _is_req_scheduled(runner, req_id)
    assert _is_req_state_block_table_match(runner, req_id)
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Generic, Optional, Protocol, TypeVar
from typing import Generic, Optional, Protocol, TypeVar, Union

import torch

@@ -26,6 +26,13 @@ class AttentionType:
    """Attention between dec. Q and enc. K/V for encoder-decoder."""


class MultipleOf:
    base: int

    def __init__(self, base: int):
        self.base = base


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

@@ -57,6 +64,10 @@
    def get_metadata_cls() -> type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]:
        return cls.get_impl_cls().get_supported_kernel_block_size()

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)
@@ -157,6 +168,11 @@ class AttentionImpl(ABC, Generic[T]):
    ) -> None:
        raise NotImplementedError

    @staticmethod
    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
        # TODO: implement this function for all backends.
        return [MultipleOf(1)]

    @abstractmethod
    def forward(
        self,
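The `MultipleOf` marker is the key new piece of the backend API: an integer entry in `get_supported_kernel_block_size()` means an exact kernel block size, while `MultipleOf(n)` means any multiple of `n` works. A minimal sketch of that contract (the `is_supported` helper is illustrative, not part of this diff):

```python
from typing import Union


class MultipleOf:
    def __init__(self, base: int):
        self.base = base


def is_supported(kernel_block_size: int, spec: Union[int, MultipleOf]) -> bool:
    # An int entry means "exactly this size"; MultipleOf(n) means
    # "any positive multiple of n".
    if isinstance(spec, MultipleOf):
        return kernel_block_size % spec.base == 0
    return kernel_block_size == spec


assert is_supported(48, MultipleOf(16))   # FlashAttention-style declaration
assert not is_supported(48, 64)           # fixed-size declaration (e.g. FlashMLA)
```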
@@ -365,6 +365,23 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
            block_size=model_config.max_model_len,
        ).page_size_bytes

        # Model may be marked as is_hybrid
        # but mamba is skipped via config,
        # return directly
        if mamba_page_size == 0:
            return

        # Attention backend constraints:
        # - FlashAttention (FA) requires block size to be multiple of 16
        # - MLA (Multi-head Latent Attention) requires larger alignment:
        # * CUTLASS_MLA backend: 128-byte alignment
        # * Other MLA backends: 64-byte alignment
        if model_config.use_mla:
            use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
        else:
            kernel_block_alignment_size = 16

        if cache_config.enable_prefix_caching:
            # With prefix caching, select attention block size to
            # optimize for mamba kernel performance
@@ -381,19 +398,28 @@
            # TODO(tdoublep): this constraint can be relaxed fairly
            # easily by changing the way we layout chunks in the
            # mamba2 kernels.
            chunk_size = model_config.get_mamba_chunk_size()

            from math import gcd

            def lcm(a, b):
                return a * b // gcd(a, b)

            base_chunk_size = model_config.get_mamba_chunk_size()
            attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)

            chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
            attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
            cache_config.mamba_block_size = attn_block_size
        else:
            # Without prefix caching, select minimum valid attention block size
            # to minimize mamba state padding

            # some attention backends (e.g. FA) only support setting
            # block size to multiple of 16, so let's suggest a value
            # that would work (note: FA is currently not compatible
            # with mamba layers, use FlashInfer instead).
            attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
            # Calculate minimum attention block size that satisfies both:
            # 1. Backend alignment requirements (kernel_block_alignment_size)
            # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
            attn_block_size = kernel_block_alignment_size * cdiv(
                mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
            )

            # override attention block size if either (a) the
            # user has not set it or (b) the user has set it
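For a concrete feel of the block-size selection above, here is a small worked example; all numbers (page sizes, chunk size, alignment) are made up for illustration and do not come from a real model config:

```python
from math import gcd


def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def lcm(a: int, b: int) -> int:
    return a * b // gcd(a, b)


# Illustrative numbers only, not taken from a real model:
attn_page_size_1_token = 256        # bytes of K/V per token
mamba_page_size = 65536             # bytes of mamba state per block
kernel_block_alignment_size = 16    # FA-style alignment; 64 or 128 for MLA backends
base_chunk_size = 256               # mamba2 chunk size

# Without prefix caching: smallest aligned block whose attention page is at
# least as large as the mamba page.
attn_block_size = kernel_block_alignment_size * cdiv(
    mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
print(attn_block_size)  # 256 tokens -> 256 * 256 bytes, matching the mamba page

# With prefix caching: additionally round up to a multiple of the
# alignment-adjusted mamba chunk size.
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
attn_block_size_pc = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
print(attn_block_size_pc)  # 256
```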
@@ -118,7 +118,15 @@ class CudaPlatformBase(Platform):

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
        # Note: block_size is initialized in
        # HybridAttentionMambaModelConfig.verify_and_update_config
        # for models with both attention and mamba,
        # and doesn't need to be reinitialized here
        if (
            model_config is not None
            and model_config.use_mla
            and cache_config.block_size is not None
        ):
            use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
@@ -151,18 +159,22 @@
            if (
                use_flashmla
                and is_flashmla_dense_supported()[0]
                and cache_config.block_size != 64
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

            if use_cutlass_mla and cache_config.block_size != 128:
            if use_cutlass_mla and cache_config.block_size % 128 != 0:
                cache_config.block_size = 128
                logger.info(
                    "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
                )

            if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
            if (
                use_flashinfer_mla
                and cache_config.block_size != 32
                and cache_config.block_size % 64 != 0
            ):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashInferMLA backend."
@@ -269,12 +281,12 @@
        use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
            selected_backend is None
            and cls.is_device_capability(100)
            and block_size == 128
            and block_size % 128 == 0
        )
        use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
            selected_backend is None
            and cls.is_device_capability(100)
            and block_size in [32, 64]
            and (block_size == 32 or block_size % 64 == 0)
        )
        use_flashmla = selected_backend == _Backend.FLASHMLA or (
            selected_backend is None and is_flashmla_dense_supported()[0]
@@ -298,7 +310,7 @@
                "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
            )
        if use_flashmla:
            if block_size != 64:
            if block_size % 64 != 0:
                logger.warning(
                    "FlashMLA backend is not supported for block size %d"
                    " (currently only supports block size 64).",
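The pattern in these checks changes from exact equality to divisibility: a KV-manager block size no longer has to match the backend's native kernel block size exactly, it only has to be a multiple of it. A tiny illustration of the FlashMLA case (the helper name and values are assumed, not taken from the diff):

```python
def flashmla_needs_override(block_size: int) -> bool:
    # Old rule: anything other than exactly 64 was overridden to 64.
    # New rule: only sizes that are not a multiple of 64 are overridden.
    return block_size % 64 != 0


assert flashmla_needs_override(48)        # still forced back to 64
assert not flashmla_needs_override(128)   # now accepted: split into 64-token kernel blocks
```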
@@ -3,7 +3,7 @@
"""Attention layer with FlashAttention."""

from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import numpy as np
import torch
@@ -14,6 +14,7 @@ from vllm.attention.backends.abstract import (
    AttentionImpl,
    AttentionMetadata,
    AttentionType,
    MultipleOf,
    is_quantized_kv_cache,
)
from vllm.attention.layer import Attention
@@ -57,6 +58,10 @@ class FlashAttentionBackend(AttentionBackend):
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
        return [MultipleOf(16)]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
@@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionType,
    MultipleOf,
)
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
@@ -165,6 +166,13 @@ class FlashInferBackend(AttentionBackend):
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        return [64, 128, 256]

    @staticmethod
    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
        # Note: Not sure for all platforms,
        # but on Blackwell, only support a page size of
        # 16, 32, 64
        return [16, 32, 64]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
@@ -10,6 +10,7 @@ import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (
    AttentionLayer,
    AttentionType,
    MultipleOf,
    is_quantized_kv_cache,
)
from vllm.logger import init_logger
@@ -44,6 +45,10 @@ class CutlassMLABackend(MLACommonBackend):
    def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
        return CutlassMLAMetadataBuilder

    @staticmethod
    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
        return [128]


class SM100Workspace:
    def __init__(self, initial_workspace_size):

@@ -6,7 +6,7 @@ from typing import ClassVar, Optional, Union

import torch

from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
from vllm.attention.ops.flashmla import (
    flash_mla_with_kvcache,
    get_mla_metadata,
@@ -44,6 +44,10 @@ class FlashMLABackend(MLACommonBackend):
    def get_impl_cls() -> type["FlashMLAImpl"]:
        return FlashMLAImpl

    @staticmethod
    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
        return [64]


@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):

@@ -3,7 +3,7 @@
"""Attention layer with AiterFlashAttention."""

from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import torch

@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import (
    AttentionImpl,
    AttentionMetadata,
    AttentionType,
    MultipleOf,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -359,6 +360,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
    def get_supported_head_sizes(cls) -> list[int]:
        return [64, 128, 256]

    @staticmethod
    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
        return [MultipleOf(16)]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
@@ -4,7 +4,7 @@

import ast
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import torch

@@ -14,6 +14,7 @@ from vllm.attention.backends.abstract import (
    AttentionImpl,
    AttentionMetadata,
    AttentionType,
    MultipleOf,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
@@ -39,6 +40,10 @@ class TreeAttentionBackend(AttentionBackend):
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
        return [MultipleOf(16)]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
@@ -3,7 +3,7 @@
"""High-Performance Triton-only Attention layer."""

from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union

import torch

@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import (
    AttentionImpl,
    AttentionMetadata,
    AttentionType,
    MultipleOf,
)
from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash,
@@ -157,6 +158,10 @@ class TritonAttentionBackend(AttentionBackend):
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16, torch.float32]

    @staticmethod
    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
        return [MultipleOf(16)]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        # Triton Attention supports any head size above 32

@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""

from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import torch

@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import (
    AttentionImpl,
    AttentionMetadata,
    AttentionType,
    MultipleOf,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
@@ -80,6 +81,10 @@ class XFormersAttentionBackend(AttentionBackend):
        256,
    ]

    @staticmethod
    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
        return [MultipleOf(16)]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
@@ -22,22 +22,64 @@ class BlockTable:
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        kernel_block_size: int,
    ):
        self.block_size = block_size
        """
        Args:
            block_size: Block size used for KV cache memory allocation
            max_num_reqs: Maximum number of concurrent requests supported.
            max_num_blocks_per_req: Maximum number of blocks per request.
            max_num_batched_tokens: Maximum number of tokens in a batch.
            pin_memory: Whether to pin memory for faster GPU transfers.
            device: Target device for the block table.
            kernel_block_size: The block_size of underlying attention kernel.
                Will be the same as `block_size` if `block_size` is supported
                by the attention kernel.
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device

        if kernel_block_size == block_size:
            # Standard case: allocation and computation use same block size
            # No block splitting needed, direct mapping
            self.block_size = block_size
            self.blocks_per_kv_block = 1
            self.use_hybrid_blocks = False
        else:
            # Hybrid case: allocation block size differs from kernel block size
            # Memory blocks are subdivided to match kernel requirements
            # Example: 32-token memory blocks with 16-token kernel blocks
            # → Each memory block corresponds to 2 kernel blocks
            if block_size % kernel_block_size != 0:
                raise ValueError(
                    f"kernel_block_size {kernel_block_size} must divide "
                    f"kv_manager_block_size size {block_size} evenly"
                )

            self.block_size = kernel_block_size
            self.blocks_per_kv_block = block_size // kernel_block_size
            self.use_hybrid_blocks = True

        self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block

        self.block_table = self._make_buffer(
            max_num_reqs, max_num_blocks_per_req, dtype=torch.int32
            self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
        )
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        self.slot_mapping = self._make_buffer(
            self.max_num_batched_tokens, dtype=torch.int64
        )

        if self.use_hybrid_blocks:
            self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
                1, -1
            )
        else:
            self._kernel_block_arange = None

        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
@@ -53,6 +95,10 @@
    ) -> None:
        if not block_ids:
            return

        if self.use_hybrid_blocks:
            block_ids = self._map_to_kernel_blocks(np.array(block_ids))

        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
        self.num_blocks_per_row[row_idx] += num_blocks
@@ -94,6 +140,7 @@
            req_indices * self.max_num_blocks_per_req
            + positions // virtual_block_size
        )

        block_numbers = self.block_table.np.ravel()[block_table_indices]
        # Use virtual_block_size for mask calculation, which marks local
        # tokens.
@@ -111,6 +158,7 @@
        block_table_indices = (
            req_indices * self.max_num_blocks_per_req + positions // self.block_size
        )

        block_numbers = self.block_table.np.ravel()[block_table_indices]
        block_offsets = positions % self.block_size
        np.add(
@@ -129,6 +177,31 @@
        self.block_table.gpu.fill_(0)
        self.block_table.cpu.fill_(0)

    def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
        """Convert kv_manager_block_id IDs to kernel block IDs.

        Example:
            # kv_manager_block_ids: 32 tokens,
            # Kernel block size: 16 tokens
            # blocks_per_kv_block = 2
            >>> kv_manager_block_ids = np.array([0, 1, 2])
            >>> Result: [0, 1, 2, 3, 4, 5]

            # Each kv_manager_block_id maps to 2 kernel block id:
            # kv_manager_block_id 0 → kernel block id [0, 1]
            # kv_manager_block_id 1 → kernel block id [2, 3]
            # kv_manager_block_id 2 → kernel block id [4, 5]
        """
        if not self.use_hybrid_blocks:
            return kv_manager_block_ids

        kernel_block_ids = (
            kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
            + self._kernel_block_arange
        )

        return kernel_block_ids.reshape(-1)

    def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table.gpu[:num_reqs]
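The arange/reshape arithmetic in `_map_to_kernel_blocks` can be reproduced standalone; a minimal sketch with an assumed split factor of 2 (32-token KV-manager blocks over 16-token kernel blocks):

```python
import numpy as np

# Assumed split factor: 32-token KV-manager blocks over 16-token kernel blocks.
blocks_per_kv_block = 2
kernel_block_arange = np.arange(blocks_per_kv_block).reshape(1, -1)

kv_manager_block_ids = np.array([0, 1, 2])
kernel_block_ids = (
    kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block + kernel_block_arange
).reshape(-1)
print(kernel_block_ids)  # [0 1 2 3 4 5]
```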
@@ -160,6 +233,7 @@ class MultiGroupBlockTable:
        pin_memory: bool,
        device: torch.device,
        block_sizes: list[int],
        kernel_block_sizes: list[int],
        num_speculative_tokens: int = 0,
    ) -> None:
        # Note(hc): each dcp rank only store
@@ -172,6 +246,12 @@
            # DCP might not be initialized in testing
            dcp_world_size = 1

        if len(kernel_block_sizes) != len(block_sizes):
            raise ValueError(
                f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
                f"must match block_sizes length ({len(block_sizes)})"
            )

        self.block_tables = [
            BlockTable(
                block_size,
@@ -183,8 +263,9 @@
                max_num_batched_tokens,
                pin_memory,
                device,
                kernel_block_size,
            )
            for block_size in block_sizes
            for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
        ]

    def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
@@ -78,6 +78,7 @@ class InputBatch:
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
@@ -132,6 +133,7 @@
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            num_speculative_tokens=num_speculative_tokens,
        )
@@ -19,7 +19,7 @@ from typing_extensions import TypeAlias

import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
from vllm.attention.layer import MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
@@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.cache_config.block_size],
            kernel_block_sizes=[self.cache_config.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=build_logitsprocs(
                self.vllm_config,
@@ -4050,6 +4051,86 @@
        else:
            self.reorder_batch_threshold = reorder_batch_threshold_i

    def _find_compatible_block_sizes(
        self,
        kv_manager_block_size: int,
        backend_cls: type[AttentionBackend],
        return_all: bool = False,
    ) -> list[int]:
        """
        Find compatible block sizes for a backend.

        Args:
            kv_manager_block_size: Physical block size of KV cache
            backend_cls: Attention backend class
            return_all: Return all compatible sizes if True, max size if False

        Returns:
            Compatible block size(s) based on return_all parameter

        Raises:
            ValueError: If no compatible block size found
        """
        supported_block_size = backend_cls.get_supported_kernel_block_size()
        compatible_sizes = []

        for block_size in supported_block_size:
            if isinstance(block_size, int):
                if kv_manager_block_size % block_size == 0:
                    compatible_sizes.append(block_size)
            elif (
                isinstance(block_size, MultipleOf)
                and kv_manager_block_size % block_size.base == 0
            ):
                compatible_sizes.append(kv_manager_block_size)

        if not compatible_sizes:
            raise ValueError(f"No compatible block size for {kv_manager_block_size}")

        return compatible_sizes if return_all else [max(compatible_sizes)]

    def _select_common_block_size(
        self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
    ) -> int:
        """
        Select common block size for all backends.

        Args:
            kv_manager_block_size: Block size of KV cache
            attn_groups: List of attention groups

        Returns:
            Block size supported by all backends,
            prioritizing cache_config.block_size

        Raises:
            ValueError: If no common block size found
        """
        all_backend_supports = []

        for attn_group in attn_groups:
            compatible_sizes = self._find_compatible_block_sizes(
                kv_manager_block_size, attn_group.backend, return_all=True
            )
            supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
            all_backend_supports.append(set(supported_sizes))

        common_supported_sizes = set.intersection(*all_backend_supports)

        if not common_supported_sizes:
            error_msg = f"No common block size for {kv_manager_block_size}. "
            for i, attn_group in enumerate(attn_groups):
                supported = all_backend_supports[i]
                error_msg += (
                    f"Backend {attn_group.backend} supports: {sorted(supported)}. "
                )
            raise ValueError(error_msg)

        if self.cache_config.block_size in common_supported_sizes:
            return self.cache_config.block_size

        return max(common_supported_sizes)

    def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Re-initialize the input batch if the block sizes are different from
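The two helpers above combine per-backend declarations into a single kernel block size per cache group. A self-contained sketch of that flow, using hypothetical backend declarations shaped like the ones added in this diff (`compatible_sizes` and `select_common` are illustrative names, not vLLM APIs):

```python
from typing import Union


class MultipleOf:
    def __init__(self, base: int):
        self.base = base


def compatible_sizes(kv_block: int, supported: list[Union[int, MultipleOf]]) -> set[int]:
    # Expand each declaration into concrete sizes that evenly divide the
    # KV-manager block size, mirroring _find_compatible_block_sizes above.
    out: set[int] = set()
    for entry in supported:
        if isinstance(entry, MultipleOf):
            if kv_block % entry.base == 0:
                out.add(kv_block)
        elif kv_block % entry == 0:
            out.add(entry)
    return out


def select_common(
    kv_block: int,
    per_backend: list[list[Union[int, MultipleOf]]],
    preferred: int,
) -> int:
    common = set.intersection(*(compatible_sizes(kv_block, s) for s in per_backend))
    if not common:
        raise ValueError(f"no common kernel block size for {kv_block}")
    # Prefer the configured cache block size when every backend accepts it.
    return preferred if preferred in common else max(common)


# KV-manager block size 64, one "multiple of 16" backend and one fixed-size backend.
print(select_common(64, [[MultipleOf(16)], [16, 32, 64]], preferred=64))  # -> 64
```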
@@ -4062,8 +4143,15 @@
        block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
            if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
        ]
        if block_sizes != [self.cache_config.block_size]:

        # Generate kernel_block_sizes that matches each block_size
        kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)

        if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
            self.cache_config.block_size
        ]:
            assert self.cache_config.cpu_offload_gb == 0, (
                "Cannot re-initialize the input batch when CPU weight "
                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
@@ -4077,6 +4165,7 @@
                pin_memory=self.pin_memory,
                vocab_size=self.model_config.get_vocab_size(),
                block_sizes=block_sizes,
                kernel_block_sizes=kernel_block_sizes,
                is_spec_decode=bool(self.vllm_config.speculative_config),
                logitsprocs=self.input_batch.logitsprocs,
                is_pooling_model=self.is_pooling_model,
@@ -4128,6 +4217,46 @@
        for attn_groups in self.attn_groups:
            yield from attn_groups

    def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]:
        """
        Generate kernel_block_sizes that matches each block_size.

        For attention backends that support virtual block splitting,
        use the supported block sizes from the backend.
        For other backends (like Mamba), use the same block size (no splitting).

        Args:
            kv_cache_config: The KV cache configuration.

        Returns:
            list[int]: List of kernel block sizes for each cache group.
        """
        kernel_block_sizes = []
        for kv_cache_group_id, kv_cache_group in enumerate(
            kv_cache_config.kv_cache_groups
        ):
            if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
                continue
            elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
                # This is an attention backend that supports virtual
                # block splitting. Get the supported block sizes from
                # all backends in the group.
                attn_groups = self.attn_groups[kv_cache_group_id]
                kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
                selected_kernel_size = self._select_common_block_size(
                    kv_manager_block_size, attn_groups
                )
                kernel_block_sizes.append(selected_kernel_size)
            elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
                # This is likely Mamba or other non-attention cache,
                # no splitting.
                kernel_block_sizes.append(kv_cache_group.kv_cache_spec.block_size)
            else:
                raise NotImplementedError(
                    f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
                )
        return kernel_block_sizes

    def _reshape_kv_cache_tensors(
        self,
        kv_cache_config: KVCacheConfig,
@@ -4157,16 +4286,24 @@
            num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
            if isinstance(kv_cache_spec, AttentionSpec):
                has_attn = True
                kv_manager_block_size = kv_cache_spec.block_size
                kernel_size_list = self._find_compatible_block_sizes(
                    kv_manager_block_size, attn_backend, return_all=False
                )
                kernel_size = kernel_size_list[0]
                num_blocks_per_kv_block = kv_manager_block_size // kernel_size
                kernel_num_blocks = num_blocks * num_blocks_per_kv_block

                kv_cache_shape = attn_backend.get_kv_cache_shape(
                    num_blocks,
                    kv_cache_spec.block_size,
                    kernel_num_blocks,
                    kernel_size,
                    kv_cache_spec.num_kv_heads,
                    kv_cache_spec.head_size,
                    cache_dtype_str=self.cache_config.cache_dtype,
                )
                dtype = kv_cache_spec.dtype
                try:
                    kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
                    kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()  # noqa: E501
                    assert len(kv_cache_stride_order) == len(kv_cache_shape)
                except (AttributeError, NotImplementedError):
                    kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
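The reshape above keeps the total cache capacity fixed while re-expressing it in kernel blocks. A small arithmetic sketch with assumed sizes:

```python
# Assumed sizes: 32-token KV-manager blocks reshaped for a 16-token kernel block.
kv_manager_block_size = 32
kernel_size = 16
num_blocks = 1024                                                # KV-manager blocks

num_blocks_per_kv_block = kv_manager_block_size // kernel_size   # 2
kernel_num_blocks = num_blocks * num_blocks_per_kv_block         # 2048

# Capacity in tokens (and bytes) is unchanged; only the granularity differs.
assert num_blocks * kv_manager_block_size == kernel_num_blocks * kernel_size
```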
@@ -4320,10 +4457,11 @@
        """
        kv_cache_config = deepcopy(kv_cache_config)
        self.kv_cache_config = kv_cache_config
        self.may_reinitialize_input_batch(kv_cache_config)
        self.may_add_encoder_only_layers_to_kv_cache_config()
        self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
        self.initialize_attn_backend(kv_cache_config)
        # Reinitialize need to after initialize_attn_backend
        self.may_reinitialize_input_batch(kv_cache_config)
        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)

        if self.speculative_config and self.speculative_config.use_eagle():
@@ -27,6 +27,7 @@ class InputBatch:
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
@@ -68,6 +69,7 @@
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
        )

        # Sampling-related.
@@ -259,6 +259,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.block_size],
            kernel_block_sizes=[self.cache_config.block_size],
        )

        # Cached torch/numpy tensor
@@ -1788,6 +1789,9 @@
            block_sizes=[
                kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
            ],
            kernel_block_sizes=[
                kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
            ],
        )
        # Verify dtype compatibility between block_table_cpu and input_batch
        assert (