mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:54:56 +08:00
[v1] Re-init input batch for multiple kv cache groups (#18654)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
6865fe0074
commit
6cac54f4d1
@ -10,8 +10,6 @@ import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
@ -25,27 +23,6 @@ CUDA_DEVICES = [
|
||||
MAX_NUM_PROMPT_TOKENS = 64
|
||||
|
||||
|
||||
def get_kv_cache_config() -> KVCacheConfig:
|
||||
return KVCacheConfig(
|
||||
num_blocks=10,
|
||||
tensors={
|
||||
"layer.0": KVCacheTensor(size=1024),
|
||||
},
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
layer_names=["layer.0"],
|
||||
kv_cache_spec=FullAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=1,
|
||||
head_size=16,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _compare_objs(obj1, obj2):
|
||||
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
|
||||
attr_names = set([
|
||||
@ -252,7 +229,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
device=torch.device(device),
|
||||
pin_memory=is_pin_memory_available(),
|
||||
vocab_size=1024,
|
||||
block_size=1,
|
||||
block_sizes=[1],
|
||||
)
|
||||
reqs: list[CachedRequestState] = []
|
||||
req_id_reqs = {}
|
||||
@ -342,7 +319,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
|
||||
device=torch.device(device),
|
||||
pin_memory=is_pin_memory_available(),
|
||||
vocab_size=1024,
|
||||
block_size=1,
|
||||
block_sizes=[1],
|
||||
)
|
||||
ref_input_batch: InputBatch = InputBatch(
|
||||
max_num_reqs=batch_size,
|
||||
@ -351,7 +328,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
|
||||
device=torch.device(device),
|
||||
pin_memory=is_pin_memory_available(),
|
||||
vocab_size=1024,
|
||||
block_size=1,
|
||||
block_sizes=[1],
|
||||
)
|
||||
|
||||
reqs: list[CachedRequestState] = []
|
||||
|
||||
@ -54,7 +54,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
|
||||
device=runner.device,
|
||||
pin_memory=runner.pin_memory,
|
||||
vocab_size=runner.model_config.get_vocab_size(),
|
||||
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size,
|
||||
block_sizes=[
|
||||
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
|
||||
],
|
||||
)
|
||||
runner.initialize_attn_backend(kv_cache_config)
|
||||
|
||||
|
||||
@ -105,10 +105,11 @@ class MultiGroupBlockTable:
|
||||
|
||||
def __init__(self, max_num_reqs: int, max_model_len: int,
|
||||
max_num_batched_tokens: int, pin_memory: bool,
|
||||
device: torch.device, block_size: int) -> None:
|
||||
device: torch.device, block_sizes: list[int]) -> None:
|
||||
self.block_tables = [
|
||||
BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
|
||||
max_num_batched_tokens, pin_memory, device)
|
||||
for block_size in block_sizes
|
||||
]
|
||||
|
||||
def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
|
||||
|
||||
@ -63,7 +63,7 @@ class InputBatch:
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
@ -105,7 +105,7 @@ class InputBatch:
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
block_size=block_size,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
# Sampling-related.
|
||||
|
||||
@ -143,7 +143,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
|
||||
self.attn_backends: list[type[AttentionBackend]] = []
|
||||
# self.kv_cache_config: KVCacheConfig
|
||||
# self.input_batch: InputBatch # Persistent batch.
|
||||
|
||||
# req_id -> (input_id -> encoder_output)
|
||||
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
|
||||
@ -173,6 +172,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Request states.
|
||||
self.requests: dict[str, CachedRequestState] = {}
|
||||
|
||||
# Input Batch
|
||||
# NOTE(Chen): Ideally, we should initialize the input batch inside
|
||||
# `initialize_kv_cache` based on the kv cache config. However, as in
|
||||
# https://github.com/vllm-project/vllm/pull/18298, due to some unknown
|
||||
# reasons, we have to initialize the input batch before `load_model`,
|
||||
# quantization + weight offloading will fail otherwise. As a temporary
|
||||
# solution, we initialize the input batch here, and re-initialize it
|
||||
# in `initialize_kv_cache` if the block_sizes here is different from
|
||||
# the block_sizes in the kv cache config.
|
||||
self.input_batch = InputBatch(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
@ -180,7 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
block_size=self.cache_config.block_size,
|
||||
block_sizes=[self.cache_config.block_size],
|
||||
)
|
||||
|
||||
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
||||
@ -2040,6 +2048,35 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.attn_backends.append(attn_backend_i)
|
||||
self.attn_metadata_builders.append(attn_metadata_builder_i)
|
||||
|
||||
def may_reinitialize_input_batch(self,
|
||||
kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Re-initialize the input batch if the block sizes are different from
|
||||
`[self.cache_config.block_size]`. This usually happens when there
|
||||
are multiple KV cache groups.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache configuration.
|
||||
"""
|
||||
block_sizes = [
|
||||
kv_cache_group.kv_cache_spec.block_size
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups
|
||||
]
|
||||
if block_sizes != [self.cache_config.block_size]:
|
||||
assert self.cache_config.cpu_offload_gb == 0, (
|
||||
"Cannot re-initialize the input batch when CPU weight "
|
||||
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
|
||||
"for more details.")
|
||||
self.input_batch = InputBatch(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
@ -2047,11 +2084,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
if len(kv_cache_config.kv_cache_groups) > 1:
|
||||
raise NotImplementedError(
|
||||
"Hybrid models with more than one KV cache type are not "
|
||||
"supported yet.")
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.may_reinitialize_input_batch(kv_cache_config)
|
||||
self.initialize_attn_backend(kv_cache_config)
|
||||
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
@ -200,7 +200,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
block_size=self.block_size,
|
||||
block_sizes=[self.block_size],
|
||||
)
|
||||
|
||||
# Cached torch/numpy tensor
|
||||
@ -1358,8 +1358,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
|
||||
block_size,
|
||||
block_sizes=[
|
||||
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
|
||||
],
|
||||
)
|
||||
# Verify dtype compatibility between block_table_cpu and input_batch
|
||||
assert self.block_table_cpu.dtype == self.input_batch.block_table[
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user