mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:05:02 +08:00
[Hybrid] Pass kernel block size to builders (#27753)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
470ad118b6
commit
18961c5ea6
@ -62,7 +62,11 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
# NOTE(tdoublep): while in principle, FA supports
|
||||
# MultipleOf(16), these are the block sizes that do not
|
||||
# suffer from the NaN propagation problem described here:
|
||||
# https://github.com/Dao-AILab/flash-attention/issues/1974
|
||||
return [16, 32, 64]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, fields
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from math import prod
|
||||
|
||||
import torch
|
||||
@ -44,6 +44,12 @@ class KVCacheSpec:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def copy_with_new_block_size(self, block_size: int) -> Self:
|
||||
"""
|
||||
Create a new KVCacheSpec from self but replacing the block size.
|
||||
"""
|
||||
return replace(self, block_size=block_size)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
"""
|
||||
|
||||
@ -4039,16 +4039,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
) -> list[AttentionGroup]:
|
||||
attn_groups: list[AttentionGroup] = []
|
||||
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
|
||||
attn_group = AttentionGroup.create_with_metadata_builders(
|
||||
attn_group = AttentionGroup(
|
||||
attn_backend,
|
||||
layer_names,
|
||||
kv_cache_spec,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
kv_cache_group_id,
|
||||
num_metadata_builders=1
|
||||
if not self.parallel_config.enable_dbo
|
||||
else 2,
|
||||
)
|
||||
|
||||
attn_groups.append(attn_group)
|
||||
@ -4067,7 +4062,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
for i, attn_backend_map in enumerate(attention_backend_maps):
|
||||
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
|
||||
|
||||
def initialize_metadata_builders(
|
||||
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
|
||||
) -> None:
|
||||
"""
|
||||
Create the metadata builders for all KV cache groups and attn groups.
|
||||
"""
|
||||
for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
attn_group.create_metadata_builders(
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
kernel_block_sizes[kv_cache_group_id]
|
||||
if kv_cache_group_id < len(kernel_block_sizes)
|
||||
else None,
|
||||
num_metadata_builders=1
|
||||
if not self.parallel_config.enable_dbo
|
||||
else 2,
|
||||
)
|
||||
# Calculate reorder batch threshold (if needed)
|
||||
# Note (tdoublep): do this *after* constructing builders,
|
||||
# because some of them change the threshold at init time.
|
||||
self.calculate_reorder_batch_threshold()
|
||||
|
||||
def _check_and_update_cudagraph_mode(
|
||||
@ -4633,6 +4648,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
|
||||
# tokens each.
|
||||
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
|
||||
|
||||
# create metadata builders
|
||||
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
|
||||
|
||||
# Reinitialize need to after initialize_attn_backend
|
||||
self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
|
||||
kv_caches = self.initialize_kv_cache_tensors(
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@ -134,31 +134,37 @@ class MultiModalBudget:
|
||||
@dataclass
|
||||
class AttentionGroup:
|
||||
backend: type[AttentionBackend]
|
||||
# When ubatching is enabled we will have a metadata builder for each ubatch
|
||||
# so that if they use internal persistant buffers for cudagraphs, and they
|
||||
# won't have to worry about conflicting with the other ubatches.
|
||||
metadata_builders: list[AttentionMetadataBuilder]
|
||||
layer_names: list[str]
|
||||
kv_cache_spec: KVCacheSpec
|
||||
kv_cache_group_id: int
|
||||
# When ubatching is enabled we will have a metadata builder for each ubatch
|
||||
# so that if they use internal persistant buffers for cudagraphs, and they
|
||||
# won't have to worry about conflicting with the other ubatches.
|
||||
metadata_builders: list[AttentionMetadataBuilder] = field(
|
||||
default_factory=lambda: []
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_with_metadata_builders(
|
||||
backend: type[AttentionBackend],
|
||||
layer_names: list[str],
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
kv_cache_group_id: int,
|
||||
def create_metadata_builders(
|
||||
self,
|
||||
vllm_config,
|
||||
device,
|
||||
kernel_block_size: int | None,
|
||||
num_metadata_builders: int = 1,
|
||||
) -> "AttentionGroup":
|
||||
metadata_builders = [
|
||||
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
|
||||
):
|
||||
kv_cache_spec_builder = (
|
||||
self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
|
||||
if kernel_block_size is not None
|
||||
else self.kv_cache_spec
|
||||
)
|
||||
self.metadata_builders = [
|
||||
self.backend.get_builder_cls()(
|
||||
kv_cache_spec_builder,
|
||||
self.layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
)
|
||||
for _ in range(num_metadata_builders)
|
||||
]
|
||||
return AttentionGroup(
|
||||
backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
|
||||
)
|
||||
|
||||
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
|
||||
assert len(self.metadata_builders) > ubatch_id
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user