[v1] Re-init input batch for multiple kv cache groups (#18654)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Authored by Chen Zhang on 2025-06-04 05:41:36 +08:00, committed by GitHub
commit 6cac54f4d1 (parent 6865fe0074)
6 changed files with 61 additions and 46 deletions

View File

@@ -10,8 +10,6 @@ import torch
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -25,27 +23,6 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS = 64
def get_kv_cache_config() -> KVCacheConfig:
return KVCacheConfig(
num_blocks=10,
tensors={
"layer.0": KVCacheTensor(size=1024),
},
kv_cache_groups=[
KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=1,
num_kv_heads=1,
head_size=16,
dtype=torch.float16,
use_mla=False,
),
),
],
)
def _compare_objs(obj1, obj2):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
@@ -252,7 +229,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
@@ -342,7 +319,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
block_sizes=[1],
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
@@ -351,7 +328,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []

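For reference, a minimal sketch of constructing the persistent batch with the new keyword, assuming a vLLM tree that includes this commit. The keyword arguments mirror the hunks above and the GPU runner change below; max_model_len and max_num_batched_tokens are illustrative values, and CPU is used so the snippet does not require a GPU (the tests themselves run on CUDA devices).

import torch

from vllm.utils import is_pin_memory_available
from vllm.v1.worker.gpu_input_batch import InputBatch

# Construct the persistent batch with one block size per KV cache group.
input_batch = InputBatch(
    max_num_reqs=8,                      # the tests parametrize this as batch_size
    max_model_len=1024,                  # illustrative value, not from the diff
    max_num_batched_tokens=1024,         # illustrative value, not from the diff
    device=torch.device("cpu"),
    pin_memory=is_pin_memory_available(),
    vocab_size=1024,
    block_sizes=[1],                     # was block_size=1 before this commit
)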
View File

@@ -54,7 +54,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size,
block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
)
runner.initialize_attn_backend(kv_cache_config)

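The helper above takes the block size from the first (and only) KV cache group. For context, a hedged sketch of a KVCacheConfig with two groups of different block sizes, modeled on the get_kv_cache_config helper removed from the test file in the first hunk; the layer names and sizes are illustrative, not taken from any real model.

import torch

from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor)

kv_cache_config = KVCacheConfig(
    num_blocks=10,
    tensors={
        "layer.0": KVCacheTensor(size=1024),
        "layer.1": KVCacheTensor(size=1024),
    },
    kv_cache_groups=[
        # Two groups with different block sizes (illustrative only).
        KVCacheGroupSpec(
            layer_names=["layer.0"],
            kv_cache_spec=FullAttentionSpec(block_size=16, num_kv_heads=1,
                                            head_size=16, dtype=torch.float16,
                                            use_mla=False),
        ),
        KVCacheGroupSpec(
            layer_names=["layer.1"],
            kv_cache_spec=FullAttentionSpec(block_size=32, num_kv_heads=1,
                                            head_size=16, dtype=torch.float16,
                                            use_mla=False),
        ),
    ],
)

# One entry per group, as derived in may_reinitialize_input_batch below.
block_sizes = [g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups]
assert block_sizes == [16, 32]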
View File

@@ -105,10 +105,11 @@ class MultiGroupBlockTable:
def __init__(self, max_num_reqs: int, max_model_len: int,
max_num_batched_tokens: int, pin_memory: bool,
device: torch.device, block_size: int) -> None:
device: torch.device, block_sizes: list[int]) -> None:
self.block_tables = [
BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
max_num_batched_tokens, pin_memory, device)
for block_size in block_sizes
]
def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:

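The constructor above allocates one BlockTable per entry in block_sizes, with each table's row length computed as cdiv(max_model_len, block_size). A self-contained illustration of that sizing arithmetic, with made-up numbers and a local re-implementation of ceiling division:

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the behavior of vllm.utils.cdiv.
    return -(a // -b)

max_model_len = 8192
block_sizes = [16, 128]   # e.g. two KV cache groups with different block sizes

# Row length (max blocks per request) of each group's BlockTable.
blocks_per_req = [cdiv(max_model_len, block_size) for block_size in block_sizes]
print(blocks_per_req)     # [512, 64]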
View File

@@ -63,7 +63,7 @@ class InputBatch:
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_size: int,
block_sizes: list[int], # The block_size of each kv cache group
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@@ -105,7 +105,7 @@ class InputBatch:
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_size=block_size,
block_sizes=block_sizes,
)
# Sampling-related.

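The InputBatch hunk above forwards block_sizes straight to the MultiGroupBlockTable shown in the previous file, so the persistent batch ends up with one BlockTable per KV cache group. A hedged construction sketch with illustrative sizes, using the parameter names from the constructor above and a CPU device so it runs without a GPU:

import torch

from vllm.utils import is_pin_memory_available
from vllm.v1.worker.block_table import MultiGroupBlockTable

# One BlockTable is created per entry of block_sizes.
block_table = MultiGroupBlockTable(
    max_num_reqs=8,
    max_model_len=1024,
    max_num_batched_tokens=1024,
    pin_memory=is_pin_memory_available(),
    device=torch.device("cpu"),
    block_sizes=[16, 32],                # illustrative two-group layout
)
assert len(block_table.block_tables) == 2   # one table per KV cache group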
View File

@@ -143,7 +143,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
self.attn_backends: list[type[AttentionBackend]] = []
# self.kv_cache_config: KVCacheConfig
# self.input_batch: InputBatch # Persistent batch.
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -173,6 +172,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
# `initialize_kv_cache` based on the kv cache config. However, as
# reported in https://github.com/vllm-project/vllm/pull/18298, for
# reasons that are not yet understood, the input batch must be
# initialized before `load_model`; otherwise quantization + weight
# offloading fails. As a temporary solution, we initialize the input
# batch here and re-initialize it in `initialize_kv_cache` if the
# block_sizes here differ from the block_sizes in the kv cache config.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
@@ -180,7 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_size=self.cache_config.block_size,
block_sizes=[self.cache_config.block_size],
)
self.use_cuda_graph = (self.vllm_config.compilation_config.level
@@ -2040,6 +2048,35 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)
def may_reinitialize_input_batch(self,
kv_cache_config: KVCacheConfig) -> None:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
are multiple KV cache groups.
Args:
kv_cache_config: The KV cache configuration.
"""
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
]
if block_sizes != [self.cache_config.block_size]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details.")
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@@ -2047,11 +2084,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
kv_caches: dict[str, torch.Tensor] = {}

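The rebuild condition in may_reinitialize_input_batch above is a plain list comparison: the per-group block sizes derived from the KV cache config are compared against the single-element list the runner used when it built the provisional batch in __init__, and the rebuild is refused when CPU weight offloading is enabled (the assertion referencing PR #18298). A standalone illustration of when the rebuild triggers; the group layouts are hypothetical:

def needs_reinit(group_block_sizes: list[int], cache_config_block_size: int) -> bool:
    # Mirrors `if block_sizes != [self.cache_config.block_size]` above.
    return group_block_sizes != [cache_config_block_size]

# Single full-attention group with the default block size: keep the batch.
print(needs_reinit([16], 16))      # False
# Two KV cache groups (e.g. a hybrid model): rebuild, even if both groups
# happen to use the same block size, because the list lengths differ.
print(needs_reinit([16, 16], 16))  # True
# One group with a non-default block size: rebuild.
print(needs_reinit([32], 16))      # True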
View File

@@ -200,7 +200,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_size=self.block_size,
block_sizes=[self.block_size],
)
# Cached torch/numpy tensor
@@ -1358,8 +1358,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
block_size,
block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
)
# Verify dtype compatibility between block_table_cpu and input_batch
assert self.block_table_cpu.dtype == self.input_batch.block_table[
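Unlike the GPU runner, the TPU hunks above still read only the first KV cache group when rebuilding the input batch inside initialize_kv_cache. A short hedged comparison of the two derivations, reusing the illustrative two-group kv_cache_config from the earlier sketch:

# GPU runner (may_reinitialize_input_batch): one block size per group.
gpu_block_sizes = [
    g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups
]
# TPU runner (hunk above): only the first group is consulted.
tpu_block_sizes = [kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size]
print(gpu_block_sizes, tpu_block_sizes)   # [16, 32] [16]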