[v1] Re-init input batch for multiple kv cache groups (#18654)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
commit 6cac54f4d1
parent 6865fe0074
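The core of this commit is that InputBatch and MultiGroupBlockTable now take block_sizes: list[int] (one block size per KV cache group) instead of a single block_size: int, and GPUModelRunner re-initializes its input batch once the KV cache config is known. Below is a minimal standalone sketch of the per-group sizing idea; the helper is illustrative and not part of vLLM or of this diff.

from math import ceil


def blocks_per_request(max_model_len: int, block_sizes: list[int]) -> list[int]:
    # One entry per KV cache group: the maximum number of blocks a single
    # request can need, given that group's block size (mirrors
    # cdiv(max_model_len, block_size) in the diff below).
    return [ceil(max_model_len / bs) for bs in block_sizes]


# A conventional model has one KV cache group; a hybrid model may have
# several, each with its own block size.
print(blocks_per_request(2048, [16]))      # [128]
print(blocks_per_request(2048, [16, 64]))  # [128, 32]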
@@ -10,8 +10,6 @@ import torch
 
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -25,27 +23,6 @@ CUDA_DEVICES = [
 MAX_NUM_PROMPT_TOKENS = 64
 
 
-def get_kv_cache_config() -> KVCacheConfig:
-    return KVCacheConfig(
-        num_blocks=10,
-        tensors={
-            "layer.0": KVCacheTensor(size=1024),
-        },
-        kv_cache_groups=[
-            KVCacheGroupSpec(
-                layer_names=["layer.0"],
-                kv_cache_spec=FullAttentionSpec(
-                    block_size=1,
-                    num_kv_heads=1,
-                    head_size=16,
-                    dtype=torch.float16,
-                    use_mla=False,
-                ),
-            ),
-        ],
-    )
-
-
 def _compare_objs(obj1, obj2):
     attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
     attr_names = set([
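For reference, the removed helper built a single-group KVCacheConfig. Under the multi-group design this commit enables, an analogous two-group config could look roughly like the sketch below. It follows only the constructor fields visible in the removed code; the second group, its layer name, and all concrete values are assumptions for illustration, not part of the diff.

import torch

from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheTensor)


def get_two_group_kv_cache_config() -> KVCacheConfig:
    # Hypothetical two-group config; field names mirror the removed helper.
    return KVCacheConfig(
        num_blocks=10,
        tensors={
            "layer.0": KVCacheTensor(size=1024),
            "layer.1": KVCacheTensor(size=1024),
        },
        kv_cache_groups=[
            KVCacheGroupSpec(
                layer_names=["layer.0"],
                kv_cache_spec=FullAttentionSpec(block_size=1,
                                                num_kv_heads=1,
                                                head_size=16,
                                                dtype=torch.float16,
                                                use_mla=False),
            ),
            KVCacheGroupSpec(
                layer_names=["layer.1"],
                kv_cache_spec=FullAttentionSpec(block_size=2,
                                                num_kv_heads=1,
                                                head_size=16,
                                                dtype=torch.float16,
                                                use_mla=False),
            ),
        ],
    )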
@@ -252,7 +229,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}
@@ -342,7 +319,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
     ref_input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
@@ -351,7 +328,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
 
     reqs: list[CachedRequestState] = []
@@ -54,7 +54,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
         device=runner.device,
         pin_memory=runner.pin_memory,
         vocab_size=runner.model_config.get_vocab_size(),
-        block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size,
+        block_sizes=[
+            kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
+        ],
     )
     runner.initialize_attn_backend(kv_cache_config)
 
@@ -105,10 +105,11 @@ class MultiGroupBlockTable:
 
     def __init__(self, max_num_reqs: int, max_model_len: int,
                  max_num_batched_tokens: int, pin_memory: bool,
-                 device: torch.device, block_size: int) -> None:
+                 device: torch.device, block_sizes: list[int]) -> None:
         self.block_tables = [
             BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
                        max_num_batched_tokens, pin_memory, device)
+            for block_size in block_sizes
         ]
 
     def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
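The updated constructor builds one BlockTable per entry in block_sizes, each sized by that group's block size. A self-contained toy sketch of the same pattern; ToyBlockTable, ToyMultiGroupBlockTable, and cdiv below are stand-ins defined here, not vLLM code.

from dataclasses import dataclass


def cdiv(a: int, b: int) -> int:
    # Ceiling division, as vllm.utils.cdiv does.
    return -(-a // b)


@dataclass
class ToyBlockTable:
    max_num_reqs: int
    max_num_blocks_per_req: int


class ToyMultiGroupBlockTable:
    # One table per KV cache group, each sized by that group's block size.
    def __init__(self, max_num_reqs: int, max_model_len: int,
                 block_sizes: list[int]) -> None:
        self.block_tables = [
            ToyBlockTable(max_num_reqs, cdiv(max_model_len, block_size))
            for block_size in block_sizes
        ]


tables = ToyMultiGroupBlockTable(max_num_reqs=8, max_model_len=2048,
                                 block_sizes=[16, 64])
print([t.max_num_blocks_per_req for t in tables.block_tables])  # [128, 32]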
@@ -56,14 +56,14 @@ class CachedRequestState:
 class InputBatch:
 
     def __init__(
         self,
         max_num_reqs: int,
         max_model_len: int,
         max_num_batched_tokens: int,
         device: torch.device,
         pin_memory: bool,
         vocab_size: int,
-        block_size: int,
+        block_sizes: list[int],  # The block_size of each kv cache group
     ):
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
@@ -105,7 +105,7 @@ class InputBatch:
             max_num_batched_tokens=max_num_batched_tokens,
             pin_memory=pin_memory,
             device=device,
-            block_size=block_size,
+            block_sizes=block_sizes,
         )
 
         # Sampling-related.
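With the new signature, a caller passes one block size per KV cache group and InputBatch forwards the list unchanged to MultiGroupBlockTable, as the hunk above shows. A hedged usage sketch follows; the concrete sizes and the device are illustrative values, not taken from the diff.

import torch

from vllm.utils import is_pin_memory_available
from vllm.v1.worker.gpu_input_batch import InputBatch

input_batch = InputBatch(
    max_num_reqs=32,
    max_model_len=2048,
    max_num_batched_tokens=4096,
    device=torch.device("cuda:0"),
    pin_memory=is_pin_memory_available(),
    vocab_size=1024,
    block_sizes=[16, 64],  # one entry per KV cache group
)
# One BlockTable is created per entry in block_sizes; per the
# MultiGroupBlockTable hunk above, each is sized by its own block size.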
@@ -143,7 +143,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
         self.attn_backends: list[type[AttentionBackend]] = []
         # self.kv_cache_config: KVCacheConfig
-        # self.input_batch: InputBatch # Persistent batch.
 
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -173,6 +172,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
 
+        # Input Batch
+        # NOTE(Chen): Ideally, we should initialize the input batch inside
+        # `initialize_kv_cache` based on the kv cache config. However, as in
+        # https://github.com/vllm-project/vllm/pull/18298, for reasons that
+        # are not yet understood, the input batch must be initialized before
+        # `load_model`; otherwise quantization + weight offloading will fail.
+        # As a temporary solution, we initialize the input batch here, and
+        # re-initialize it in `initialize_kv_cache` if the block_sizes here
+        # differ from the block_sizes in the kv cache config.
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
@@ -180,7 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
-            block_size=self.cache_config.block_size,
+            block_sizes=[self.cache_config.block_size],
         )
 
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
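Taken together with the NOTE above, initialization is two-phase: the constructor first builds an InputBatch for a single default block size so that model loading can proceed, and initialize_kv_cache may rebuild it once the real per-group block sizes are known. A standalone toy model of that flow; the class and attribute names here are simplified stand-ins, not the vLLM runner.

class ToyRunner:
    def __init__(self, default_block_size: int) -> None:
        # Phase 1: before the KV cache config exists, assume one group with
        # the default block size so that model loading can proceed.
        self.default_block_size = default_block_size
        self.input_batch_block_sizes = [default_block_size]

    def initialize_kv_cache(self, group_block_sizes: list[int]) -> None:
        # Phase 2: rebuild the batch state only if the config disagrees with
        # the provisional single-group assumption.
        if group_block_sizes != [self.default_block_size]:
            self.input_batch_block_sizes = group_block_sizes


runner = ToyRunner(default_block_size=16)
runner.initialize_kv_cache([16])       # single group: no rebuild needed
print(runner.input_batch_block_sizes)  # [16]
runner.initialize_kv_cache([16, 64])   # hybrid config: rebuild
print(runner.input_batch_block_sizes)  # [16, 64]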
@@ -2040,6 +2048,35 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.attn_backends.append(attn_backend_i)
         self.attn_metadata_builders.append(attn_metadata_builder_i)
 
+    def may_reinitialize_input_batch(self,
+                                     kv_cache_config: KVCacheConfig) -> None:
+        """
+        Re-initialize the input batch if the block sizes are different from
+        `[self.cache_config.block_size]`. This usually happens when there
+        are multiple KV cache groups.
+
+        Args:
+            kv_cache_config: The KV cache configuration.
+        """
+        block_sizes = [
+            kv_cache_group.kv_cache_spec.block_size
+            for kv_cache_group in kv_cache_config.kv_cache_groups
+        ]
+        if block_sizes != [self.cache_config.block_size]:
+            assert self.cache_config.cpu_offload_gb == 0, (
+                "Cannot re-initialize the input batch when CPU weight "
+                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
+                "for more details.")
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=self.pin_memory,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_sizes=block_sizes,
+            )
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2047,11 +2084,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             kv_cache_config: Configuration for the KV cache, including the KV
                 cache size of each layer
         """
-        if len(kv_cache_config.kv_cache_groups) > 1:
-            raise NotImplementedError(
-                "Hybrid models with more than one KV cache type are not "
-                "supported yet.")
         self.kv_cache_config = kv_cache_config
+        self.may_reinitialize_input_batch(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
 
         kv_caches: dict[str, torch.Tensor] = {}
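One subtlety of the comparison in may_reinitialize_input_batch: the batch is rebuilt whenever the per-group block sizes are not exactly a one-element list containing the default, so even two groups that share the default block size trigger re-initialization. A tiny standalone check illustrates this; the function name is illustrative only.

def needs_reinit(group_block_sizes: list[int], default_block_size: int) -> bool:
    # Mirrors the `block_sizes != [self.cache_config.block_size]` test above.
    return group_block_sizes != [default_block_size]


print(needs_reinit([16], 16))      # False: one group with the default size
print(needs_reinit([16, 16], 16))  # True: two groups, even with equal sizes
print(needs_reinit([16, 64], 16))  # True: differing block sizes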
@@ -200,7 +200,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
-            block_size=self.block_size,
+            block_sizes=[self.block_size],
         )
 
         # Cached torch/numpy tensor
@@ -1358,8 +1358,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
-            block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
-            block_size,
+            block_sizes=[
+                kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
+            ],
         )
         # Verify dtype compatibility between block_table_cpu and input_batch
         assert self.block_table_cpu.dtype == self.input_batch.block_table[