[V1] [Hybrid] Validate compatibility of attention backend batch reordering at init time (#21557)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2025-08-02 14:29:40 +02:00 committed by GitHub
parent f5d0f4784f
commit 4abfd8796f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 96 additions and 72 deletions

View File

@ -4,7 +4,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Union
from typing import ClassVar, Optional, Union
import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@ -21,17 +21,17 @@ from vllm.logger import init_logger
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import use_trtllm_decode_attention
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout,
get_per_layer_parameters,
infer_global_hyperparameters,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
logger = init_logger(__name__)
@ -179,6 +179,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.device = device
@ -239,12 +241,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int32,
device=self.device)
def reorder_batch(self, input_batch: InputBatch,
                  scheduler_output: SchedulerOutput) -> bool:
    """
    Let the shared decode/prefill split helper reorder the batch.

    Delegates to ``reorder_batch_to_split_decodes_and_prefills`` with a
    fixed decode threshold of 1.

    :return: The helper's result, passed through unchanged.
    """
    batch_reordered = reorder_batch_to_split_decodes_and_prefills(
        input_batch,
        scheduler_output,
        decode_threshold=1,
    )
    return batch_reordered
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(

View File

@ -2,21 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
@ -87,6 +83,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
@ -95,12 +93,6 @@ class Mamba2AttentionMetadataBuilder(
assert self.chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models")
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """
    Let the shared decode/prefill split helper reorder the batch.

    Delegates to ``reorder_batch_to_split_decodes_and_prefills`` with a
    fixed decode threshold of 1.

    :return: The helper's result, passed through unchanged.
    """
    batch_reordered = reorder_batch_to_split_decodes_and_prefills(
        input_batch,
        scheduler_output,
        decode_threshold=1,
    )
    return batch_reordered
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,

View File

@ -190,7 +190,7 @@ return curr_o @ W_O
import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
from typing import ClassVar, Generic, Optional, TypeVar, Union
import torch
@ -210,10 +210,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
get_per_layer_parameters, infer_global_hyperparameters,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
get_per_layer_parameters,
infer_global_hyperparameters,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
try:
@ -233,10 +234,6 @@ try:
except ImportError:
flashinfer_available = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
logger = init_logger(__name__)
CUDNN_WORKSPACE_SIZE = 12800
@ -403,6 +400,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self,
kv_cache_spec: AttentionSpec,
@ -559,12 +557,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill.prefill_main = self._fi_prefill_main
prefill.prefill_chunks = self._fi_prefill_chunks
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """
    Let the shared decode/prefill split helper reorder the batch.

    Delegates to ``reorder_batch_to_split_decodes_and_prefills`` with a
    fixed decode threshold of 1.

    :return: The helper's result, passed through unchanged.
    """
    batch_reordered = reorder_batch_to_split_decodes_and_prefills(
        input_batch,
        scheduler_output,
        decode_threshold=1,
    )
    return batch_reordered
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(

View File

@ -251,9 +251,6 @@ class AiterFlashAttentionMetadataBuilder(
self.aot_sliding_window: Optional[tuple[int, int]] = None
self.total_tokens: int = 0
def reorder_batch(self, input_batch, scheduler_output) -> bool:
    """
    Leave the batch order untouched.

    Both arguments are accepted for interface compatibility and ignored.

    :return: Always False.
    """
    # Nothing is inspected or moved; the result is a constant.
    return False
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
self.total_tokens = self.model_config.max_model_len \

View File

@ -167,6 +167,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[Optional[int]] = None
@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
@ -221,14 +225,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
) -> bool:
return False
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """
    Hook that lets a backend reorder the batch if it wants to.

    This default implementation ignores both arguments and changes
    nothing.

    :return: Whether the batch has been reordered (always False here).
    """
    # Default: no reordering is performed.
    return False
@functools.lru_cache
def get_kv_cache_layout():

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any
from typing import TYPE_CHECKING, Any
import torch
import torch.nn as nn
@ -9,8 +9,12 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
@ -27,6 +31,34 @@ class CPUModelRunner(GPUModelRunner):
self._postprocess_tenosrs()
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
    """
    Update the order of requests in the batch based on the attention
    backend's needs. For example, some attention backends (namely MLA) may
    want to separate requests based on if the attention computation will be
    compute-bound or memory-bound.

    Args:
        scheduler_output: The scheduler output.

    Raises:
        ValueError: If more than one KV cache group is configured.
    """
    # Attention free models have zero kv_cache_groups, however models
    # like Mamba are also attention free but use the kv_cache for
    # keeping its internal state. This is why we check the number
    # of kv_cache groups instead of solely checking
    # for self.model_config.is_attention_free.
    num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
    if num_kv_cache_groups == 0:
        return

    if num_kv_cache_groups > 1:
        # Adjacent string literals are joined verbatim, so the first
        # fragment must end with a space to keep the words separated.
        raise ValueError("Multiple KVCacheGroups is not "
                         "currently supported with CPU model runner.")

    # Exactly one group remains; its builder performs the reordering.
    assert type(
        self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1
    self.attn_metadata_builders[0].reorder_batch(self.input_batch,
                                                 scheduler_output)
def _postprocess_tenosrs(self) -> None:
# Note: replace device tensors with cpu tensors
def replace_tensor(obj: Any, cpu_attr_name: str,

View File

@ -49,7 +49,8 @@ from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
make_local_attention_virtual_batches)
make_local_attention_virtual_batches,
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec,
@ -329,6 +330,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=self.device)
self.reorder_batch_threshold: Optional[int] = None
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
@ -347,20 +350,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if len(self.kv_cache_config.kv_cache_groups) == 0:
return
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
scheduler_output)
# For models with multiple KV cache groups, the groups should agree on
# the same order of requests. We ensure this by only allowing the first
# group to reorder the batch and asserting that all other groups do not
# reorder the batch.
# TODO(tdoublep): make this more flexible so that any group can
# re-order the batch (not only the first).
# TODO(tdoublep): verify this during engine init instead of at runtime
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
batch_reordered = self.attn_metadata_builders[i].reorder_batch(
self.input_batch, scheduler_output)
assert not batch_reordered
if self.reorder_batch_threshold is not None:
reorder_batch_to_split_decodes_and_prefills(
self.input_batch,
scheduler_output,
decode_threshold=self.reorder_batch_threshold)
# Note: used for model runner override.
def _init_device_properties(self) -> None:
@ -2654,6 +2648,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)
# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
if len(self.attn_backends) > 0:
return
@ -2688,6 +2685,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.attn_metadata_builders.append(attn_metadata_builder)
self.is_encoder_only_model = True
def calculate_reorder_batch_threshold(self) -> None:
    """
    Settle on a single reorder threshold shared by all attention backends.

    Walks ``self.attn_metadata_builders`` and inspects each builder's
    ``reorder_batch_threshold``. Builders whose threshold is None are
    ignored. The first non-None threshold seen is stored on
    ``self.reorder_batch_threshold`` (unless one is already set); every
    later non-None threshold must equal the stored one.

    Raises:
        ValueError: If a builder's threshold differs from the one already
            stored on the runner.
    """
    for builder in self.attn_metadata_builders:
        candidate = builder.reorder_batch_threshold
        if candidate is None:
            # Nothing to reconcile for this builder.
            continue

        agreed = self.reorder_batch_threshold
        if agreed is None:
            # First threshold encountered becomes the reference value.
            self.reorder_batch_threshold = candidate
            continue

        if candidate != agreed:
            raise ValueError(
                f"Attention backend reorders decodes with "
                f"threshold {candidate} but other "
                f"backend uses threshold "
                f"{agreed}")
def may_reinitialize_input_batch(self,
kv_cache_config: KVCacheConfig) -> None:
"""