[Hybrid] Pass kernel block size to builders (#27753)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
commit 18961c5ea6
parent 470ad118b6
@@ -62,7 +62,11 @@ class FlashAttentionBackend(AttentionBackend):

     @staticmethod
     def get_supported_kernel_block_size() -> list[int | MultipleOf]:
-        return [MultipleOf(16)]
+        # NOTE(tdoublep): while in principle, FA supports
+        # MultipleOf(16), these are the block sizes that do not
+        # suffer from the NaN propagation problem described here:
+        # https://github.com/Dao-AILab/flash-attention/issues/1974
+        return [16, 32, 64]

     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
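A backend's supported-size list can mix exact ints (the new FA list) with MultipleOf constraints (the old one), so any caller resolving a concrete kernel block size has to handle both. Below is a minimal, hypothetical sketch of such matching logic; it is not vLLM's actual resolver (that selection happens via the runner's _prepare_kernel_block_sizes, seen later in this diff), and MultipleOf here is a local stand-in for vLLM's class of the same name.

from dataclasses import dataclass

@dataclass(frozen=True)
class MultipleOf:  # stand-in for vLLM's MultipleOf
    base: int

def pick_kernel_block_size(
    cache_block_size: int, supported: list[int | MultipleOf]
) -> int:
    # Prefer the largest exact size that evenly divides the cache block,
    # so one cache block maps onto a whole number of kernel blocks.
    for size in sorted((s for s in supported if isinstance(s, int)), reverse=True):
        if cache_block_size % size == 0:
            return size
    # Fall back to a MultipleOf constraint: keep the cache block size
    # itself if it is a multiple of the required base.
    for s in supported:
        if isinstance(s, MultipleOf) and cache_block_size % s.base == 0:
            return cache_block_size
    raise ValueError(f"no supported kernel block size for {cache_block_size}")

assert pick_kernel_block_size(256, [16, 32, 64]) == 64       # new FA list
assert pick_kernel_block_size(256, [MultipleOf(16)]) == 256  # old constraint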
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import copy
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
 from math import prod

 import torch
@@ -44,6 +44,12 @@ class KVCacheSpec:
         """
         raise NotImplementedError

+    def copy_with_new_block_size(self, block_size: int) -> Self:
+        """
+        Create a new KVCacheSpec from self but replacing the block size.
+        """
+        return replace(self, block_size=block_size)
+
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         """
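Because KVCacheSpec and its subclasses are dataclasses, dataclasses.replace copies every field except the ones overridden and returns an instance of the same (sub)class, which is what makes the Self return type above work. A self-contained sketch, using a hypothetical spec class:

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class DummySpec:  # hypothetical stand-in for a KVCacheSpec subclass
    block_size: int
    num_kv_heads: int
    head_size: int

spec = DummySpec(block_size=256, num_kv_heads=8, head_size=128)
kernel_spec = replace(spec, block_size=64)

assert type(kernel_spec) is DummySpec        # same class as the original
assert kernel_spec.block_size == 64          # only this field changed
assert kernel_spec.head_size == spec.head_size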
@@ -4039,16 +4039,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     ) -> list[AttentionGroup]:
         attn_groups: list[AttentionGroup] = []
         for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
-            attn_group = AttentionGroup.create_with_metadata_builders(
+            attn_group = AttentionGroup(
                 attn_backend,
                 layer_names,
                 kv_cache_spec,
-                self.vllm_config,
-                self.device,
                 kv_cache_group_id,
-                num_metadata_builders=1
-                if not self.parallel_config.enable_dbo
-                else 2,
             )

             attn_groups.append(attn_group)
@@ -4067,7 +4062,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for i, attn_backend_map in enumerate(attention_backend_maps):
             self.attn_groups.append(create_attn_groups(attn_backend_map, i))

+    def initialize_metadata_builders(
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
+    ) -> None:
+        """
+        Create the metadata builders for all KV cache groups and attn groups.
+        """
+        for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                attn_group.create_metadata_builders(
+                    self.vllm_config,
+                    self.device,
+                    kernel_block_sizes[kv_cache_group_id]
+                    if kv_cache_group_id < len(kernel_block_sizes)
+                    else None,
+                    num_metadata_builders=1
+                    if not self.parallel_config.enable_dbo
+                    else 2,
+                )
         # Calculate reorder batch threshold (if needed)
+        # Note (tdoublep): do this *after* constructing builders,
+        # because some of them change the threshold at init time.
         self.calculate_reorder_batch_threshold()

     def _check_and_update_cudagraph_mode(
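Two per-group decisions are folded into the call above: which kernel block size to pass (None meaning "keep the spec's own block size", as the AttentionGroup change below shows) and how many metadata builders to create (two under dual-batch overlap, one per ubatch). A tiny sketch of just those decisions, with hypothetical names:

def builder_args(
    kernel_block_sizes: list[int], group_id: int, enable_dbo: bool
) -> tuple[int | None, int]:
    # None = no resolved kernel block size for this group; the builder
    # will then see the KV cache spec unchanged.
    kernel_block_size = (
        kernel_block_sizes[group_id]
        if group_id < len(kernel_block_sizes)
        else None
    )
    # One metadata builder per ubatch: two when DBO is enabled.
    return kernel_block_size, (2 if enable_dbo else 1)

assert builder_args([64, 16], 0, enable_dbo=False) == (64, 1)
assert builder_args([64], 1, enable_dbo=True) == (None, 2)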
@@ -4633,6 +4648,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
         # tokens each.
         kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
+
+        # create metadata builders
+        self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
+
         # Reinitialize need to after initialize_attn_backend
         self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
         kv_caches = self.initialize_kv_cache_tensors(
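The context comment above describes splitting one 256-token cache block into four 64-token kernel blocks. The index arithmetic behind such a split is straightforward; here is a hypothetical sketch, not vLLM's actual block-table mapping code:

def expand_block_ids(
    cache_block_ids: list[int], cache_block_size: int, kernel_block_size: int
) -> list[int]:
    # Each cache block of size N covers N // kernel_block_size consecutive
    # kernel blocks, so cache block id b expands to ids b*f .. b*f + f - 1.
    assert cache_block_size % kernel_block_size == 0
    factor = cache_block_size // kernel_block_size
    return [
        block_id * factor + offset
        for block_id in cache_block_ids
        for offset in range(factor)
    ]

# Cache block 5 of size 256 covers kernel blocks 20..23 of size 64.
assert expand_block_ids([5], 256, 64) == [20, 21, 22, 23]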
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING

 import torch
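The new field import is needed because a dataclass field cannot take a mutable default directly; metadata_builders now defaults to an empty list via default_factory, as the next hunk shows. A short demonstration of the rule:

from dataclasses import dataclass, field

@dataclass
class Ok:
    items: list[int] = field(default_factory=list)  # accepted

try:
    @dataclass
    class Bad:
        items: list[int] = []  # rejected: bare mutable default
except ValueError as exc:
    print(exc)  # mutable default <class 'list'> for field items is not allowed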
@@ -134,31 +134,37 @@ class MultiModalBudget:
 @dataclass
 class AttentionGroup:
     backend: type[AttentionBackend]
-    # When ubatching is enabled we will have a metadata builder for each ubatch
-    # so that if they use internal persistant buffers for cudagraphs, and they
-    # won't have to worry about conflicting with the other ubatches.
-    metadata_builders: list[AttentionMetadataBuilder]
     layer_names: list[str]
     kv_cache_spec: KVCacheSpec
     kv_cache_group_id: int
+    # When ubatching is enabled we will have a metadata builder for each ubatch
+    # so that if they use internal persistant buffers for cudagraphs, and they
+    # won't have to worry about conflicting with the other ubatches.
+    metadata_builders: list[AttentionMetadataBuilder] = field(
+        default_factory=lambda: []
+    )

-    @staticmethod
-    def create_with_metadata_builders(
-        backend: type[AttentionBackend],
-        layer_names: list[str],
-        kv_cache_spec: KVCacheSpec,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        kv_cache_group_id: int,
+    def create_metadata_builders(
+        self,
+        vllm_config,
+        device,
+        kernel_block_size: int | None,
         num_metadata_builders: int = 1,
-    ) -> "AttentionGroup":
-        metadata_builders = [
-            backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
+    ):
+        kv_cache_spec_builder = (
+            self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
+            if kernel_block_size is not None
+            else self.kv_cache_spec
+        )
+        self.metadata_builders = [
+            self.backend.get_builder_cls()(
+                kv_cache_spec_builder,
+                self.layer_names,
+                vllm_config,
+                device,
+            )
             for _ in range(num_metadata_builders)
         ]
-        return AttentionGroup(
-            backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
-        )

     def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
         assert len(self.metadata_builders) > ubatch_id
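The net effect of this hunk is a two-phase lifecycle: an AttentionGroup is constructed without builders, and create_metadata_builders attaches them later, once the group's kernel block size is known. A runnable sketch of that lifecycle with stub classes (the real builder constructor takes the spec, layer names, config, and device, as the hunk shows):

from dataclasses import dataclass, field

class StubBuilder:  # stands in for an AttentionMetadataBuilder
    def __init__(self, kv_cache_spec):
        self.kv_cache_spec = kv_cache_spec

@dataclass
class StubGroup:  # stands in for AttentionGroup
    kv_cache_group_id: int
    metadata_builders: list = field(default_factory=list)

    def create_metadata_builders(self, kv_cache_spec, num_metadata_builders=1):
        self.metadata_builders = [
            StubBuilder(kv_cache_spec) for _ in range(num_metadata_builders)
        ]

    def get_metadata_builder(self, ubatch_id: int = 0):
        assert len(self.metadata_builders) > ubatch_id
        return self.metadata_builders[ubatch_id]

group = StubGroup(kv_cache_group_id=0)       # phase 1: no builders yet
group.create_metadata_builders("spec64", 2)  # phase 2: one per ubatch
assert group.get_metadata_builder(1).kv_cache_spec == "spec64"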