[Hybrid] Pass kernel block size to builders (#27753)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2025-11-03 06:48:03 +01:00 committed by GitHub
parent 470ad118b6
commit 18961c5ea6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 62 additions and 27 deletions

View File

@ -62,7 +62,11 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64]
@classmethod
def validate_head_size(cls, head_size: int) -> None:

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from dataclasses import dataclass, fields
from dataclasses import dataclass, fields, replace
from math import prod
import torch
@ -44,6 +44,12 @@ class KVCacheSpec:
"""
raise NotImplementedError
def copy_with_new_block_size(self, block_size: int) -> Self:
"""
Create a new KVCacheSpec from self but replacing the block size.
"""
return replace(self, block_size=block_size)
@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""

View File

@ -4039,16 +4039,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
attn_group = AttentionGroup.create_with_metadata_builders(
attn_group = AttentionGroup(
attn_backend,
layer_names,
kv_cache_spec,
self.vllm_config,
self.device,
kv_cache_group_id,
num_metadata_builders=1
if not self.parallel_config.enable_dbo
else 2,
)
attn_groups.append(attn_group)
@ -4067,7 +4062,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for i, attn_backend_map in enumerate(attention_backend_maps):
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
def initialize_metadata_builders(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None:
"""
Create the metadata builders for all KV cache groups and attn groups.
"""
for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
for attn_group in self.attn_groups[kv_cache_group_id]:
attn_group.create_metadata_builders(
self.vllm_config,
self.device,
kernel_block_sizes[kv_cache_group_id]
if kv_cache_group_id < len(kernel_block_sizes)
else None,
num_metadata_builders=1
if not self.parallel_config.enable_dbo
else 2,
)
# Calculate reorder batch threshold (if needed)
# Note (tdoublep): do this *after* constructing builders,
# because some of them change the threshold at init time.
self.calculate_reorder_batch_threshold()
def _check_and_update_cudagraph_mode(
@ -4633,6 +4648,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
# tokens each.
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
# create metadata builders
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
# Reinitialize need to after initialize_attn_backend
self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
kv_caches = self.initialize_kv_cache_tensors(

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
@ -134,31 +134,37 @@ class MultiModalBudget:
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
# When ubatching is enabled we will have a metadata builder for each ubatch
# so that if they use internal persistant buffers for cudagraphs, and they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str]
kv_cache_spec: KVCacheSpec
kv_cache_group_id: int
# When ubatching is enabled we will have a metadata builder for each ubatch
# so that if they use internal persistant buffers for cudagraphs, and they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder] = field(
default_factory=lambda: []
)
@staticmethod
def create_with_metadata_builders(
backend: type[AttentionBackend],
layer_names: list[str],
kv_cache_spec: KVCacheSpec,
vllm_config: VllmConfig,
device: torch.device,
kv_cache_group_id: int,
def create_metadata_builders(
self,
vllm_config,
device,
kernel_block_size: int | None,
num_metadata_builders: int = 1,
) -> "AttentionGroup":
metadata_builders = [
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
):
kv_cache_spec_builder = (
self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
if kernel_block_size is not None
else self.kv_cache_spec
)
self.metadata_builders = [
self.backend.get_builder_cls()(
kv_cache_spec_builder,
self.layer_names,
vllm_config,
device,
)
for _ in range(num_metadata_builders)
]
return AttentionGroup(
backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
)
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id