Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 06:35:00 +08:00)
[Hybrid] A simpler algorithm to find kernel_block_size (#26476)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent 0e0a638c3b
commit df334868ca
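
For orientation before the diff, here is a minimal standalone sketch of the selection algorithm this commit introduces. MultipleOf is a simplified stand-in for vllm.attention.backends.abstract.MultipleOf, and MiniBackend plus the final checks are hypothetical scaffolding; only the selection logic mirrors the new GPUModelRunner.select_common_block_size below.

    # Minimal sketch of the kernel-block-size selection introduced in this commit.
    # MiniBackend and the demo values are made up; the selection logic mirrors the diff.
    from dataclasses import dataclass


    @dataclass(frozen=True)
    class MultipleOf:
        # "Any block size that is a multiple of `base` is supported."
        base: int


    @dataclass
    class MiniBackend:
        supported: list  # mix of ints and MultipleOf constraints

        def get_supported_kernel_block_size(self):
            return self.supported


    def is_supported_by_all(backends, block_size: int) -> bool:
        # A block size is usable only if every backend accepts it, either as an
        # exact int or via a MultipleOf divisibility constraint.
        for backend in backends:
            ok = False
            for s in backend.get_supported_kernel_block_size():
                if isinstance(s, int):
                    ok = ok or block_size == s
                elif isinstance(s, MultipleOf):
                    ok = ok or block_size % s.base == 0
            if not ok:
                return False
        return True


    def select_common_block_size(kv_manager_block_size: int, backends) -> int:
        # Case 1: the KV cache manager's block size itself works for every backend.
        if is_supported_by_all(backends, kv_manager_block_size):
            return kv_manager_block_size
        # Case 2: otherwise the answer must be one of the explicit int sizes;
        # scan them in descending order, keeping only divisors of the manager size.
        int_sizes = {
            s
            for b in backends
            for s in b.get_supported_kernel_block_size()
            if isinstance(s, int)
        }
        for size in sorted(int_sizes, reverse=True):
            if kv_manager_block_size % size == 0 and is_supported_by_all(backends, size):
                return size
        raise ValueError(f"No common block size for {kv_manager_block_size}")


    if __name__ == "__main__":
        # Mirrors the new unit tests: 128 is kept as-is; 256 falls back to 64.
        assert select_common_block_size(
            128, [MiniBackend([MultipleOf(32)]), MiniBackend([64, MultipleOf(16)])]
        ) == 128
        assert select_common_block_size(
            256, [MiniBackend([128, 64]), MiniBackend([64, 32])]
        ) == 64
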
@@ -6,6 +6,7 @@ import pytest
 import torch
 
 from vllm.attention import Attention
+from vllm.attention.backends.abstract import MultipleOf
 from vllm.config import (
     CacheConfig,
     ModelConfig,
@@ -34,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.utils import AttentionGroup
 
 BLOCK_SIZE = 16
 NUM_BLOCKS = 10
@@ -181,6 +183,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
     ).all()
 
 
+def _make_mock_backend_for_kernel_block_size(
+    supported_sizes: list[int | MultipleOf],
+):
+    class _MockBackend:
+        @staticmethod
+        def get_supported_kernel_block_size():
+            return supported_sizes
+
+    return _MockBackend()
+
+
+def _make_kv_cache_spec() -> FullAttentionSpec:
+    return FullAttentionSpec(block_size=1, num_kv_heads=1, head_size=1, dtype="float16")
+
+
+def test_select_common_block_size_prefers_manager_block_size():
+    backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
+    backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
+    attn_groups = [
+        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
+        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
+    ]
+
+    selected_size = GPUModelRunner.select_common_block_size(128, attn_groups)
+    assert selected_size == 128
+
+
+def test_select_common_block_size_uses_largest_shared_int():
+    backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
+    backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
+    attn_groups = [
+        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
+        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
+    ]
+
+    selected_size = GPUModelRunner.select_common_block_size(256, attn_groups)
+    assert selected_size == 64
+
+
+def test_select_common_block_size_no_valid_option():
+    backend_a = _make_mock_backend_for_kernel_block_size([64])
+    backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
+    attn_groups = [
+        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
+        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
+    ]
+
+    with pytest.raises(ValueError):
+        GPUModelRunner.select_common_block_size(48, attn_groups)
+
+
 def test_update_states_new_request(model_runner, dist_init):
     req_id = "req_0"
@@ -3978,6 +3978,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         def create_attn_groups(
             attn_backends_map: dict[AttentionGroupKey, list[str]],
+            kv_cache_group_id: int,
         ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
             for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
@@ -3987,6 +3988,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     kv_cache_spec,
                     self.vllm_config,
                     self.device,
+                    kv_cache_group_id,
                     num_metadata_builders=1
                     if not self.parallel_config.enable_dbo
                     else 2,
@@ -4005,8 +4007,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Resolve cudagraph_mode before actually initialize metadata_builders
         self._check_and_update_cudagraph_mode(attention_backend_set)
 
-        for attn_backends_map in attention_backend_maps:
-            self.attn_groups.append(create_attn_groups(attn_backends_map))
+        for i, attn_backend_map in enumerate(attention_backend_maps):
+            self.attn_groups.append(create_attn_groups(attn_backend_map, i))
 
         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()
@@ -4156,87 +4158,81 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return
         self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
 
-    def _find_compatible_block_sizes(
-        self,
-        kv_manager_block_size: int,
-        backend_cls: type[AttentionBackend],
-        return_all: bool = False,
-    ) -> list[int]:
-        """
-        Find compatible block sizes for a backend.
-
-        Args:
-            kv_manager_block_size: Physical block size of KV cache
-            backend_cls: Attention backend class
-            return_all: Return all compatible sizes if True, max size if False
-
-        Returns:
-            Compatible block size(s) based on return_all parameter
-
-        Raises:
-            ValueError: If no compatible block size found
-        """
-        supported_block_size = backend_cls.get_supported_kernel_block_size()
-        compatible_sizes = []
-
-        for block_size in supported_block_size:
-            if isinstance(block_size, int):
-                if kv_manager_block_size % block_size == 0:
-                    compatible_sizes.append(block_size)
-            elif (
-                isinstance(block_size, MultipleOf)
-                and kv_manager_block_size % block_size.base == 0
-            ):
-                compatible_sizes.append(kv_manager_block_size)
-
-        if not compatible_sizes:
-            raise ValueError(f"No compatible block size for {kv_manager_block_size}")
-
-        return compatible_sizes if return_all else [max(compatible_sizes)]
-
-    def _select_common_block_size(
-        self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
+    @staticmethod
+    def select_common_block_size(
+        kv_manager_block_size: int, attn_groups: list[AttentionGroup]
     ) -> int:
         """
-        Select common block size for all backends.
+        Select a block size that is supported by all backends and is a factor of
+        kv_manager_block_size.
+
+        If kv_manager_block_size is supported by all backends, return it directly.
+        Otherwise, return the max supported size.
 
         Args:
             kv_manager_block_size: Block size of KV cache
             attn_groups: List of attention groups
 
         Returns:
-            Block size supported by all backends,
-            prioritizing cache_config.block_size
+            The selected block size
 
         Raises:
-            ValueError: If no common block size found
+            ValueError: If no valid block size found
         """
-        all_backend_supports = []
-
-        for attn_group in attn_groups:
-            compatible_sizes = self._find_compatible_block_sizes(
-                kv_manager_block_size, attn_group.backend, return_all=True
-            )
-            supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
-            all_backend_supports.append(set(supported_sizes))
-
-        common_supported_sizes = set.intersection(*all_backend_supports)
-
-        if not common_supported_sizes:
-            error_msg = f"No common block size for {kv_manager_block_size}. "
-            for i, attn_group in enumerate(attn_groups):
-                supported = all_backend_supports[i]
-                error_msg += (
-                    f"Backend {attn_group.backend} supports: {sorted(supported)}. "
-                )
-            raise ValueError(error_msg)
-
-        if self.cache_config.block_size in common_supported_sizes:
-            return self.cache_config.block_size
-
-        return max(common_supported_sizes)
+
+        def block_size_is_supported(
+            backends: list[type[AttentionBackend]], block_size: int
+        ) -> bool:
+            """
+            Check if the block size is supported by all backends.
+            """
+            for backend in backends:
+                is_supported = False
+                for supported_size in backend.get_supported_kernel_block_size():
+                    if isinstance(supported_size, int):
+                        if block_size == supported_size:
+                            is_supported = True
+                    elif isinstance(supported_size, MultipleOf):
+                        if block_size % supported_size.base == 0:
+                            is_supported = True
+                    else:
+                        raise ValueError(f"Unknown supported size: {supported_size}")
+                if not is_supported:
+                    return False
+            return True
+
+        backends = [group.backend for group in attn_groups]
+
+        # Case 1: if the block_size of kv cache manager is supported by all backends,
+        # return it directly
+        if block_size_is_supported(backends, kv_manager_block_size):
+            return kv_manager_block_size
+
+        # Case 2: otherwise, the block_size must be an `int`-format supported size of
+        # at least one backend. Iterate over all `int`-format supported sizes in
+        # descending order and return the first one that is supported by all backends.
+        # Simple proof:
+        # If the supported size b is in MultipleOf(x_i) format for all attention
+        # backends i, and b is a factor of kv_manager_block_size, then
+        # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
+        # return kv_manager_block_size in case 1.
+        all_int_supported_sizes = set(
+            supported_size
+            for backend in backends
+            for supported_size in backend.get_supported_kernel_block_size()
+            if isinstance(supported_size, int)
+        )
+
+        for supported_size in sorted(all_int_supported_sizes, reverse=True):
+            if kv_manager_block_size % supported_size != 0:
+                continue
+            if block_size_is_supported(backends, supported_size):
+                return supported_size
+        raise ValueError(f"No common block size for {kv_manager_block_size}. ")
 
-    def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
+    def may_reinitialize_input_batch(
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
+    ) -> None:
         """
         Re-initialize the input batch if the block sizes are different from
         `[self.cache_config.block_size]`. This usually happens when there
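
The "Simple proof" comment above rests on transitivity of divisibility; a few lines of arithmetic with illustrative (made-up) values make the step explicit.

    # Illustrative check of the "Simple proof" above: if a candidate size b divides
    # kv_manager_block_size and every backend constraint is MultipleOf(x_i) with
    # b % x_i == 0, then kv_manager_block_size % x_i == 0 as well, so Case 1 would
    # already have returned kv_manager_block_size. The values here are made up.
    kv_manager_block_size = 256
    candidate_b = 64                 # divides 256
    multiple_of_bases = [16, 32, 8]  # hypothetical MultipleOf(x_i) constraints

    assert kv_manager_block_size % candidate_b == 0
    for x in multiple_of_bases:
        # b satisfies MultipleOf(x) ...
        assert candidate_b % x == 0
        # ... and because x divides b and b divides kv_manager_block_size,
        # x also divides kv_manager_block_size (divisibility is transitive).
        assert kv_manager_block_size % x == 0
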
@@ -4244,6 +4240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         Args:
             kv_cache_config: The KV cache configuration.
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
         """
         block_sizes = [
             kv_cache_group.kv_cache_spec.block_size
@@ -4251,9 +4248,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
         ]
 
-        # Generate kernel_block_sizes that matches each block_size
-        kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
-
         if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
             self.cache_config.block_size
         ]:
@@ -4354,7 +4348,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # all backends in the group.
             attn_groups = self.attn_groups[kv_cache_group_id]
             kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
-            selected_kernel_size = self._select_common_block_size(
+            selected_kernel_size = self.select_common_block_size(
                 kv_manager_block_size, attn_groups
             )
             kernel_block_sizes.append(selected_kernel_size)
@@ -4372,6 +4366,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self,
         kv_cache_config: KVCacheConfig,
         kv_cache_raw_tensors: dict[str, torch.Tensor],
+        kernel_block_sizes: list[int],
     ) -> dict[str, torch.Tensor]:
         """
         Reshape the KV cache tensors to the desired shape and dtype.
@@ -4380,6 +4375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_cache_config: The KV cache config
             kv_cache_raw_tensors: The KV cache buffer of each layer, with
                 correct size but uninitialized shape.
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
                 corresponding memory buffer for KV cache.
@@ -4389,6 +4385,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for group in self._kv_cache_spec_attn_group_iterator():
             kv_cache_spec = group.kv_cache_spec
             attn_backend = group.backend
+            if group.kv_cache_group_id == len(kernel_block_sizes):
+                # There may be a last group for layers without kv cache.
+                continue
+            kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
             for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
@@ -4397,24 +4397,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
                 if isinstance(kv_cache_spec, AttentionSpec):
                     has_attn = True
-                    kv_manager_block_size = kv_cache_spec.block_size
-                    kernel_size_list = self._find_compatible_block_sizes(
-                        kv_manager_block_size, attn_backend, return_all=False
-                    )
-                    kernel_size = kernel_size_list[0]
-                    num_blocks_per_kv_block = kv_manager_block_size // kernel_size
+                    num_blocks_per_kv_block = (
+                        kv_cache_spec.block_size // kernel_block_size
+                    )
                     kernel_num_blocks = num_blocks * num_blocks_per_kv_block
 
                     kv_cache_shape = attn_backend.get_kv_cache_shape(
                         kernel_num_blocks,
-                        kernel_size,
+                        kernel_block_size,
                         kv_cache_spec.num_kv_heads,
                         kv_cache_spec.head_size,
                         cache_dtype_str=self.cache_config.cache_dtype,
                     )
                     dtype = kv_cache_spec.dtype
                     try:
-                        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()  # noqa: E501
+                        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
                         assert len(kv_cache_stride_order) == len(kv_cache_shape)
                     except (AttributeError, NotImplementedError):
                         kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@@ -4497,13 +4494,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
 
     def initialize_kv_cache_tensors(
-        self, kv_cache_config: KVCacheConfig
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
     ) -> dict[str, torch.Tensor]:
         """
         Initialize the memory buffer for KV cache.
 
         Args:
             kv_cache_config: The KV cache config
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
+
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
                 corresponding memory buffer for KV cache.
@@ -4512,7 +4511,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
         # Change the memory buffer to the desired shape
         kv_caches = self._reshape_kv_cache_tensors(
-            kv_cache_config, kv_cache_raw_tensors
+            kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
         )
 
         # Set up cross-layer KV cache sharing
@@ -4571,9 +4570,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
+        # The kernel block size for all KV cache groups. For example, if
+        # kv_cache_manager uses block_size 256 for a given group, but the attention
+        # backends for that group only support block_size 64, we will return
+        # kernel_block_size 64 and split each 256-token block into 4 blocks of 64
+        # tokens each.
+        kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
         # Reinitialize need to after initialize_attn_backend
-        self.may_reinitialize_input_batch(kv_cache_config)
-        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
+        self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
+        kv_caches = self.initialize_kv_cache_tensors(
+            kv_cache_config, kernel_block_sizes
+        )
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
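
A tiny arithmetic sketch of the manager-block to kernel-block split described in the new comment above; all numbers are illustrative, not taken from the commit.

    # Illustrative arithmetic for the reshape above: a manager-level block is split
    # into several kernel-level blocks, so the backend sees more, smaller blocks.
    num_blocks = 1000                # manager-level blocks carved from the raw tensor
    kv_manager_block_size = 256      # kv_cache_spec.block_size (tokens per manager block)
    kernel_block_size = 64           # as selected by select_common_block_size

    num_blocks_per_kv_block = kv_manager_block_size // kernel_block_size  # 4
    kernel_num_blocks = num_blocks * num_blocks_per_kv_block              # 4000

    # The backend's get_kv_cache_shape is then asked for 4000 blocks of 64 tokens,
    # while the KV cache manager keeps addressing 1000 blocks of 256 tokens;
    # the total token capacity is unchanged.
    assert kernel_num_blocks * kernel_block_size == num_blocks * kv_manager_block_size
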
@@ -140,6 +140,7 @@ class AttentionGroup:
     metadata_builders: list[AttentionMetadataBuilder]
     layer_names: list[str]
     kv_cache_spec: KVCacheSpec
+    kv_cache_group_id: int
 
     @staticmethod
     def create_with_metadata_builders(
@@ -148,13 +149,16 @@ class AttentionGroup:
         kv_cache_spec: KVCacheSpec,
         vllm_config: VllmConfig,
         device: torch.device,
+        kv_cache_group_id: int,
         num_metadata_builders: int = 1,
     ) -> "AttentionGroup":
         metadata_builders = [
             backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
             for _ in range(num_metadata_builders)
         ]
-        return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec)
+        return AttentionGroup(
+            backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
+        )
 
     def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
         assert len(self.metadata_builders) > ubatch_id
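
A minimal sketch of the widened AttentionGroup and how the new kv_cache_group_id field is consumed, using placeholder types rather than the real vLLM classes.

    # Minimal sketch of the widened AttentionGroup; Backend/Builder/Spec are stand-ins.
    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class AttentionGroupSketch:
        backend: Any
        metadata_builders: list[Any]
        layer_names: list[str]
        kv_cache_spec: Any
        kv_cache_group_id: int  # new field: which KV cache group this group belongs to


    # The group id is what lets _reshape_kv_cache_tensors look up the kernel block
    # size selected for this group: kernel_block_sizes[group.kv_cache_group_id].
    group = AttentionGroupSketch(
        backend=object(), metadata_builders=[], layer_names=["layer.0"],
        kv_cache_spec=object(), kv_cache_group_id=0,
    )
    kernel_block_sizes = [64]
    assert kernel_block_sizes[group.kv_cache_group_id] == 64
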