[Hybrid] Pass kernel block size to builders (#27753)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell
Date: 2025-11-03 06:48:03 +01:00 (committed by GitHub)
Commit: 18961c5ea6
Parent: 470ad118b6
4 changed files with 62 additions and 27 deletions


@@ -62,7 +62,11 @@ class FlashAttentionBackend(AttentionBackend):
     @staticmethod
     def get_supported_kernel_block_size() -> list[int | MultipleOf]:
-        return [MultipleOf(16)]
+        # NOTE(tdoublep): while in principle, FA supports
+        # MultipleOf(16), these are the block sizes that do not
+        # suffer from the NaN propagation problem described here:
+        # https://github.com/Dao-AILab/flash-attention/issues/1974
+        return [16, 32, 64]
 
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
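
To make the effect concrete, here is a small sketch of how a cache block size might be checked against the constraints a backend advertises. MultipleOf is modelled as a stand-in dataclass and is_supported is a hypothetical helper, not vLLM code; the point is only that an explicit list such as [16, 32, 64] is stricter than MultipleOf(16).

    from dataclasses import dataclass


    @dataclass(frozen=True)
    class MultipleOf:          # stand-in for vLLM's MultipleOf marker
        base: int


    def is_supported(block_size: int, supported: list) -> bool:
        """Hypothetical check of a block size against advertised constraints."""
        for constraint in supported:
            if isinstance(constraint, int) and block_size == constraint:
                return True
            if isinstance(constraint, MultipleOf) and block_size % constraint.base == 0:
                return True
        return False


    # Before this commit: any multiple of 16 was considered supported.
    assert is_supported(256, [MultipleOf(16)])
    # After: only the listed sizes pass, so a larger KV-cache block must be
    # split into kernel blocks of 16, 32 or 64 tokens.
    assert not is_supported(256, [16, 32, 64])
    assert is_supported(64, [16, 32, 64])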


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import copy
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
 from math import prod
 
 import torch
@@ -44,6 +44,12 @@ class KVCacheSpec:
         """
         raise NotImplementedError
 
+    def copy_with_new_block_size(self, block_size: int) -> Self:
+        """
+        Create a new KVCacheSpec from self but replacing the block size.
+        """
+        return replace(self, block_size=block_size)
+
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         """


@@ -4039,16 +4039,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
             for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
-                attn_group = AttentionGroup.create_with_metadata_builders(
+                attn_group = AttentionGroup(
                     attn_backend,
                     layer_names,
                     kv_cache_spec,
-                    self.vllm_config,
-                    self.device,
                     kv_cache_group_id,
-                    num_metadata_builders=1
-                    if not self.parallel_config.enable_dbo
-                    else 2,
                 )
                 attn_groups.append(attn_group)
@@ -4067,7 +4062,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for i, attn_backend_map in enumerate(attention_backend_maps):
             self.attn_groups.append(create_attn_groups(attn_backend_map, i))
 
+    def initialize_metadata_builders(
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
+    ) -> None:
+        """
+        Create the metadata builders for all KV cache groups and attn groups.
+        """
+        for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                attn_group.create_metadata_builders(
+                    self.vllm_config,
+                    self.device,
+                    kernel_block_sizes[kv_cache_group_id]
+                    if kv_cache_group_id < len(kernel_block_sizes)
+                    else None,
+                    num_metadata_builders=1
+                    if not self.parallel_config.enable_dbo
+                    else 2,
+                )
+
         # Calculate reorder batch threshold (if needed)
+        # Note (tdoublep): do this *after* constructing builders,
+        # because some of them change the threshold at init time.
         self.calculate_reorder_batch_threshold()
 
     def _check_and_update_cudagraph_mode(
@@ -4633,6 +4648,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
         # tokens each.
         kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
+
+        # create metadata builders
+        self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
+
         # Reinitialize need to after initialize_attn_backend
         self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
         kv_caches = self.initialize_kv_cache_tensors(
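
The arithmetic referred to in the comment above is plain division of the allocated cache block by the kernel block size, and the per-group lookup in initialize_metadata_builders falls back to None when no kernel block size was computed for a group. A rough, self-contained sketch (helper names are illustrative, not vLLM APIs):

    def split_cache_block(cache_block_size: int, kernel_block_size: int) -> int:
        # A 256-token KV-cache block served with kernel block size 64 is viewed
        # as 4 contiguous kernel blocks.
        assert cache_block_size % kernel_block_size == 0, (
            "kernel block size must evenly divide the cache block size"
        )
        return cache_block_size // kernel_block_size


    assert split_cache_block(256, 64) == 4

    # Per-group lookup with fallback, mirroring the indexing in the new method:
    kernel_block_sizes = [64]           # one entry per KV cache group (illustrative)
    for kv_cache_group_id in range(2):  # pretend there are two groups
        kernel_block_size = (
            kernel_block_sizes[kv_cache_group_id]
            if kv_cache_group_id < len(kernel_block_sizes)
            else None
        )
        print(kv_cache_group_id, kernel_block_size)  # prints "0 64", then "1 None"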


@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING
 
 import torch
@@ -134,31 +134,37 @@ class MultiModalBudget:
 
 @dataclass
 class AttentionGroup:
     backend: type[AttentionBackend]
-    # When ubatching is enabled we will have a metadata builder for each ubatch
-    # so that if they use internal persistant buffers for cudagraphs, and they
-    # won't have to worry about conflicting with the other ubatches.
-    metadata_builders: list[AttentionMetadataBuilder]
     layer_names: list[str]
     kv_cache_spec: KVCacheSpec
     kv_cache_group_id: int
+    # When ubatching is enabled we will have a metadata builder for each ubatch
+    # so that if they use internal persistant buffers for cudagraphs, and they
+    # won't have to worry about conflicting with the other ubatches.
+    metadata_builders: list[AttentionMetadataBuilder] = field(
+        default_factory=lambda: []
+    )
 
-    @staticmethod
-    def create_with_metadata_builders(
-        backend: type[AttentionBackend],
-        layer_names: list[str],
-        kv_cache_spec: KVCacheSpec,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        kv_cache_group_id: int,
+    def create_metadata_builders(
+        self,
+        vllm_config,
+        device,
+        kernel_block_size: int | None,
         num_metadata_builders: int = 1,
-    ) -> "AttentionGroup":
-        metadata_builders = [
-            backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
+    ):
+        kv_cache_spec_builder = (
+            self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
+            if kernel_block_size is not None
+            else self.kv_cache_spec
+        )
+        self.metadata_builders = [
+            self.backend.get_builder_cls()(
+                kv_cache_spec_builder,
+                self.layer_names,
+                vllm_config,
+                device,
+            )
             for _ in range(num_metadata_builders)
         ]
-        return AttentionGroup(
-            backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
-        )
 
     def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
         assert len(self.metadata_builders) > ubatch_id
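
The net effect of the refactor is a two-phase construction: the AttentionGroup is created first without builders, and create_metadata_builders attaches them later, once the kernel block size for the group is known, re-blocking the spec the builders see while leaving the group's own spec untouched. A toy mirror of that pattern (simplified classes, not vLLM's):

    from dataclasses import dataclass, field


    @dataclass
    class ToySpec:
        block_size: int


    @dataclass
    class ToyBuilder:
        spec: ToySpec


    @dataclass
    class ToyAttentionGroup:
        kv_cache_spec: ToySpec
        metadata_builders: list[ToyBuilder] = field(default_factory=list)

        def create_metadata_builders(
            self, kernel_block_size: int | None, num_metadata_builders: int = 1
        ) -> None:
            # Builders see the spec re-blocked to the kernel block size; the
            # group itself keeps the original allocation block size.
            spec = (
                ToySpec(block_size=kernel_block_size)
                if kernel_block_size is not None
                else self.kv_cache_spec
            )
            self.metadata_builders = [
                ToyBuilder(spec) for _ in range(num_metadata_builders)
            ]


    group = ToyAttentionGroup(kv_cache_spec=ToySpec(block_size=256))
    # Two builders per group when micro-batching (DBO/ubatching) is enabled.
    group.create_metadata_builders(kernel_block_size=64, num_metadata_builders=2)
    assert group.kv_cache_spec.block_size == 256
    assert all(b.spec.block_size == 64 for b in group.metadata_builders)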