diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 9007436350be..23ab70480fbb 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
 from vllm.attention import Attention
+from vllm.attention.backends.abstract import MultipleOf
 from vllm.config import (
     CacheConfig,
     ModelConfig,
@@ -34,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.utils import AttentionGroup

 BLOCK_SIZE = 16
 NUM_BLOCKS = 10
@@ -181,6 +183,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
     ).all()


+def _make_mock_backend_for_kernel_block_size(
+    supported_sizes: list[int | MultipleOf],
+):
+    class _MockBackend:
+        @staticmethod
+        def get_supported_kernel_block_size():
+            return supported_sizes
+
+    return _MockBackend()
+
+
+def _make_kv_cache_spec() -> FullAttentionSpec:
+    return FullAttentionSpec(block_size=1, num_kv_heads=1, head_size=1, dtype="float16")
+
+
+def test_select_common_block_size_prefers_manager_block_size():
+    backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
+    backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
+    attn_groups = [
+        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
+        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
+    ]
+
+    selected_size = GPUModelRunner.select_common_block_size(128, attn_groups)
+    assert selected_size == 128
+
+
+def test_select_common_block_size_uses_largest_shared_int():
+    backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
+    backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
+    attn_groups = [
+        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
+        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
+    ]
+
+    selected_size = GPUModelRunner.select_common_block_size(256, attn_groups)
+    assert selected_size == 64
+
+
+def test_select_common_block_size_no_valid_option():
+    backend_a = _make_mock_backend_for_kernel_block_size([64])
+    backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
+    attn_groups = [
+        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
+        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
+    ]
+
+    with pytest.raises(ValueError):
+        GPUModelRunner.select_common_block_size(48, attn_groups)
+
+
 def test_update_states_new_request(model_runner, dist_init):
     req_id = "req_0"
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 747a7b377e40..ba852bb89f33 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3978,6 +3978,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         def create_attn_groups(
             attn_backends_map: dict[AttentionGroupKey, list[str]],
+            kv_cache_group_id: int,
         ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
             for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
@@ -3987,6 +3988,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     kv_cache_spec,
                     self.vllm_config,
                     self.device,
+                    kv_cache_group_id,
                     num_metadata_builders=1
                     if not self.parallel_config.enable_dbo
                     else 2,
@@ -4005,8 +4007,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         # Resolve cudagraph_mode before actually initialize metadata_builders
         self._check_and_update_cudagraph_mode(attention_backend_set)
-        for attn_backends_map in attention_backend_maps:
-            self.attn_groups.append(create_attn_groups(attn_backends_map))
+        for i, attn_backend_map in enumerate(attention_backend_maps):
+            self.attn_groups.append(create_attn_groups(attn_backend_map, i))

         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()
@@ -4156,87 +4158,81 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 return
         self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)

-    def _find_compatible_block_sizes(
-        self,
-        kv_manager_block_size: int,
-        backend_cls: type[AttentionBackend],
-        return_all: bool = False,
-    ) -> list[int]:
-        """
-        Find compatible block sizes for a backend.
-
-        Args:
-            kv_manager_block_size: Physical block size of KV cache
-            backend_cls: Attention backend class
-            return_all: Return all compatible sizes if True, max size if False
-
-        Returns:
-            Compatible block size(s) based on return_all parameter
-
-        Raises:
-            ValueError: If no compatible block size found
-        """
-        supported_block_size = backend_cls.get_supported_kernel_block_size()
-        compatible_sizes = []
-
-        for block_size in supported_block_size:
-            if isinstance(block_size, int):
-                if kv_manager_block_size % block_size == 0:
-                    compatible_sizes.append(block_size)
-            elif (
-                isinstance(block_size, MultipleOf)
-                and kv_manager_block_size % block_size.base == 0
-            ):
-                compatible_sizes.append(kv_manager_block_size)
-
-        if not compatible_sizes:
-            raise ValueError(f"No compatible block size for {kv_manager_block_size}")
-
-        return compatible_sizes if return_all else [max(compatible_sizes)]
-
-    def _select_common_block_size(
-        self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
+    @staticmethod
+    def select_common_block_size(
+        kv_manager_block_size: int, attn_groups: list[AttentionGroup]
     ) -> int:
         """
-        Select common block size for all backends.
+        Select a block size that is supported by all backends and is a factor of
+        kv_manager_block_size.
+
+        If kv_manager_block_size is supported by all backends, return it directly.
+        Otherwise, return the largest supported size that divides it.

         Args:
             kv_manager_block_size: Block size of KV cache
             attn_groups: List of attention groups

         Returns:
-            Block size supported by all backends,
-            prioritizing cache_config.block_size
+            The selected block size

         Raises:
-            ValueError: If no common block size found
+            ValueError: If no valid block size found
         """
-        all_backend_supports = []
-        for attn_group in attn_groups:
-            compatible_sizes = self._find_compatible_block_sizes(
-                kv_manager_block_size, attn_group.backend, return_all=True
-            )
-            supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
-            all_backend_supports.append(set(supported_sizes))
+        def block_size_is_supported(
+            backends: list[type[AttentionBackend]], block_size: int
+        ) -> bool:
+            """
+            Check if the block size is supported by all backends.
+ """ + for backend in backends: + is_supported = False + for supported_size in backend.get_supported_kernel_block_size(): + if isinstance(supported_size, int): + if block_size == supported_size: + is_supported = True + elif isinstance(supported_size, MultipleOf): + if block_size % supported_size.base == 0: + is_supported = True + else: + raise ValueError(f"Unknown supported size: {supported_size}") + if not is_supported: + return False + return True - common_supported_sizes = set.intersection(*all_backend_supports) + backends = [group.backend for group in attn_groups] - if not common_supported_sizes: - error_msg = f"No common block size for {kv_manager_block_size}. " - for i, attn_group in enumerate(attn_groups): - supported = all_backend_supports[i] - error_msg += ( - f"Backend {attn_group.backend} supports: {sorted(supported)}. " - ) - raise ValueError(error_msg) + # Case 1: if the block_size of kv cache manager is supported by all backends, + # return it directly + if block_size_is_supported(backends, kv_manager_block_size): + return kv_manager_block_size - if self.cache_config.block_size in common_supported_sizes: - return self.cache_config.block_size + # Case 2: otherwise, the block_size must be an `int`-format supported size of + # at least one backend. Iterate over all `int`-format supported sizes in + # descending order and return the first one that is supported by all backends. + # Simple proof: + # If the supported size b is in MultipleOf(x_i) format for all attention + # backends i, and b a factor of kv_manager_block_size, then + # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will + # return kv_manager_block_size in case 1. + all_int_supported_sizes = set( + supported_size + for backend in backends + for supported_size in backend.get_supported_kernel_block_size() + if isinstance(supported_size, int) + ) - return max(common_supported_sizes) + for supported_size in sorted(all_int_supported_sizes, reverse=True): + if kv_manager_block_size % supported_size != 0: + continue + if block_size_is_supported(backends, supported_size): + return supported_size + raise ValueError(f"No common block size for {kv_manager_block_size}. ") - def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: + def may_reinitialize_input_batch( + self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] + ) -> None: """ Re-initialize the input batch if the block sizes are different from `[self.cache_config.block_size]`. This usually happens when there @@ -4244,6 +4240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: kv_cache_config: The KV cache configuration. + kernel_block_sizes: The kernel block sizes for each KV cache group. """ block_sizes = [ kv_cache_group.kv_cache_spec.block_size @@ -4251,9 +4248,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) ] - # Generate kernel_block_sizes that matches each block_size - kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) - if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ self.cache_config.block_size ]: @@ -4354,7 +4348,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # all backends in the group. 
             attn_groups = self.attn_groups[kv_cache_group_id]
             kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
-            selected_kernel_size = self._select_common_block_size(
+            selected_kernel_size = self.select_common_block_size(
                 kv_manager_block_size, attn_groups
             )
             kernel_block_sizes.append(selected_kernel_size)
@@ -4372,6 +4366,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self,
         kv_cache_config: KVCacheConfig,
         kv_cache_raw_tensors: dict[str, torch.Tensor],
+        kernel_block_sizes: list[int],
     ) -> dict[str, torch.Tensor]:
         """
         Reshape the KV cache tensors to the desired shape and dtype.
@@ -4380,6 +4375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_cache_config: The KV cache config
             kv_cache_raw_tensors: The KV cache buffer of each layer, with
                 correct size but uninitialized shape.
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
             corresponding memory buffer for KV cache.
@@ -4389,6 +4385,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for group in self._kv_cache_spec_attn_group_iterator():
             kv_cache_spec = group.kv_cache_spec
             attn_backend = group.backend
+            if group.kv_cache_group_id == len(kernel_block_sizes):
+                # There may be a last group for layers without kv cache.
+                continue
+            kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
             for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
@@ -4397,24 +4397,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
                 if isinstance(kv_cache_spec, AttentionSpec):
                     has_attn = True
-                    kv_manager_block_size = kv_cache_spec.block_size
-                    kernel_size_list = self._find_compatible_block_sizes(
-                        kv_manager_block_size, attn_backend, return_all=False
+                    num_blocks_per_kv_block = (
+                        kv_cache_spec.block_size // kernel_block_size
                     )
-                    kernel_size = kernel_size_list[0]
-                    num_blocks_per_kv_block = kv_manager_block_size // kernel_size
                     kernel_num_blocks = num_blocks * num_blocks_per_kv_block
                     kv_cache_shape = attn_backend.get_kv_cache_shape(
                         kernel_num_blocks,
-                        kernel_size,
+                        kernel_block_size,
                         kv_cache_spec.num_kv_heads,
                         kv_cache_spec.head_size,
                         cache_dtype_str=self.cache_config.cache_dtype,
                     )
                     dtype = kv_cache_spec.dtype
                     try:
-                        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()  # noqa: E501
+                        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
                         assert len(kv_cache_stride_order) == len(kv_cache_shape)
                     except (AttributeError, NotImplementedError):
                         kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@@ -4497,13 +4494,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )

     def initialize_kv_cache_tensors(
-        self, kv_cache_config: KVCacheConfig
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
     ) -> dict[str, torch.Tensor]:
         """
         Initialize the memory buffer for KV cache.

         Args:
             kv_cache_config: The KV cache config
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
+
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
             corresponding memory buffer for KV cache.
@@ -4512,7 +4511,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
         # Change the memory buffer to the desired shape
         kv_caches = self._reshape_kv_cache_tensors(
-            kv_cache_config, kv_cache_raw_tensors
+            kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
        )

         # Set up cross-layer KV cache sharing
@@ -4571,9 +4570,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
+        # The kernel block size for each KV cache group. For example, if the
+        # kv_cache_manager uses block_size 256 for a given group but the attention
+        # backends for that group only support block_size 64, we use
+        # kernel_block_size 64 and split each 256-token block into four blocks of
+        # 64 tokens each.
+        kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
         # Reinitialize need to after initialize_attn_backend
-        self.may_reinitialize_input_batch(kv_cache_config)
-        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
+        self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
+        kv_caches = self.initialize_kv_cache_tensors(
+            kv_cache_config, kernel_block_sizes
+        )

         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 92baf0cb7136..396adbcfb289 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -140,6 +140,7 @@ class AttentionGroup:
     metadata_builders: list[AttentionMetadataBuilder]
     layer_names: list[str]
     kv_cache_spec: KVCacheSpec
+    kv_cache_group_id: int

     @staticmethod
     def create_with_metadata_builders(
@@ -148,13 +149,16 @@ class AttentionGroup:
         kv_cache_spec: KVCacheSpec,
         vllm_config: VllmConfig,
         device: torch.device,
+        kv_cache_group_id: int,
         num_metadata_builders: int = 1,
     ) -> "AttentionGroup":
         metadata_builders = [
             backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
             for _ in range(num_metadata_builders)
         ]
-        return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec)
+        return AttentionGroup(
+            backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
+        )

     def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
         assert len(self.metadata_builders) > ubatch_id
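
For reference, below is a self-contained sketch of the selection rule implemented by the new `GPUModelRunner.select_common_block_size`; it can be read and run without vLLM installed. The `MultipleOf` class here is a toy stand-in for `vllm.attention.backends.abstract.MultipleOf`, and each backend is represented only by its list of supported kernel block sizes; everything else is illustrative, not vLLM API.

```python
from dataclasses import dataclass


@dataclass
class MultipleOf:
    """Toy stand-in: any block size divisible by `base` is accepted."""

    base: int


def select_common_block_size(kv_manager_block_size: int, backends: list[list]) -> int:
    """Pick the kernel block size shared by every backend's supported sizes."""

    def supported_by_all(block_size: int) -> bool:
        def ok(size) -> bool:
            # Exact match for int sizes, divisibility for MultipleOf sizes.
            if isinstance(size, int):
                return block_size == size
            return block_size % size.base == 0

        return all(any(ok(s) for s in sizes) for sizes in backends)

    # Case 1: the KV cache manager's own block size works for every backend.
    if supported_by_all(kv_manager_block_size):
        return kv_manager_block_size

    # Case 2: fall back to the largest int-format size that divides the manager
    # block size and is accepted by every backend.
    int_sizes = {s for sizes in backends for s in sizes if isinstance(s, int)}
    for size in sorted(int_sizes, reverse=True):
        if kv_manager_block_size % size == 0 and supported_by_all(size):
            return size
    raise ValueError(f"No common block size for {kv_manager_block_size}")


# Mirrors the three new unit tests:
assert select_common_block_size(128, [[MultipleOf(32)], [64, MultipleOf(16)]]) == 128
assert select_common_block_size(256, [[128, 64], [64, 32]]) == 64
# select_common_block_size(48, [[64], [MultipleOf(16)]]) raises ValueError
```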
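
For the `_reshape_kv_cache_tensors` change, the only new arithmetic is viewing each manager block as several kernel blocks; a minimal sketch with assumed example numbers (local names only, not vLLM code):

```python
# Assumed example values, chosen only to illustrate the arithmetic.
num_blocks = 1000          # blocks as seen by the KV cache manager
manager_block_size = 256   # kv_cache_spec.block_size
kernel_block_size = 64     # result of select_common_block_size

num_blocks_per_kv_block = manager_block_size // kernel_block_size  # 4
kernel_num_blocks = num_blocks * num_blocks_per_kv_block           # 4000

# The raw buffer holds the same number of token slots either way, so only the
# shape passed to attn_backend.get_kv_cache_shape() changes, not its size.
assert num_blocks * manager_block_size == kernel_num_blocks * kernel_block_size
```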