From 4abfd8796f37adc8fccc9481f37f20de1bce62e4 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sat, 2 Aug 2025 14:29:40 +0200 Subject: [PATCH] [V1] [Hybrid] Validate compatibility of attention backend batch reordering at init time (#21557) Signed-off-by: Thomas Parnell --- vllm/v1/attention/backends/flashinfer.py | 28 +++++------- vllm/v1/attention/backends/mamba_attn.py | 20 +++------ vllm/v1/attention/backends/mla/common.py | 22 +++------ vllm/v1/attention/backends/rocm_aiter_fa.py | 3 -- vllm/v1/attention/backends/utils.py | 12 ++--- vllm/v1/worker/cpu_model_runner.py | 34 +++++++++++++- vllm/v1/worker/gpu_model_runner.py | 49 ++++++++++++++------- 7 files changed, 96 insertions(+), 72 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 0aaad02b5b840..3697cb9387a92 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Optional, Union +from typing import ClassVar, Optional, Union import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, @@ -21,17 +21,17 @@ from vllm.logger import init_logger from vllm.utils import cdiv, is_pin_memory_available from vllm.utils.flashinfer import use_trtllm_decode_attention from vllm.v1.attention.backends.flash_attn import use_cascade_attention -from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - get_kv_cache_layout, get_per_layer_parameters, - infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, - split_decodes_and_prefills) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_kv_cache_layout, + get_per_layer_parameters, + infer_global_hyperparameters, + 
split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 logger = init_logger(__name__) @@ -179,6 +179,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.PURE_DECODE_ONLY + reorder_batch_threshold: ClassVar[int] = 1 + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.device = device @@ -239,12 +241,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): dtype=torch.int32, device=self.device) - def reorder_batch(self, input_batch: InputBatch, - scheduler_output: SchedulerOutput) -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) - def _get_workspace_buffer(self): if self._workspace_buffer is None: self._workspace_buffer = torch.empty( diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 8b702e28d67c0..66a8d91db89c2 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -2,21 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import ClassVar, Optional import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) 
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, chunk_size: int, @@ -87,6 +83,8 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba2AttentionMetadata]): + reorder_batch_threshold: ClassVar[int] = 1 + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): assert isinstance(kv_cache_spec, MambaSpec) @@ -95,12 +93,6 @@ class Mamba2AttentionMetadataBuilder( assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models") - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) - def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index d112468f1c91d..badff67656c24 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -190,7 +190,7 @@ return curr_o @ W_O import functools from abc import abstractmethod from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union +from typing import ClassVar, Generic, Optional, TypeVar, Union import torch @@ -210,10 +210,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm.utils.flashinfer import has_nvidia_artifactory -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - get_per_layer_parameters, 
infer_global_hyperparameters, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec try: @@ -233,10 +234,6 @@ try: except ImportError: flashinfer_available = False -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - logger = init_logger(__name__) CUDNN_WORKSPACE_SIZE = 12800 @@ -403,6 +400,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + reorder_batch_threshold: ClassVar[int] = 1 def __init__(self, kv_cache_spec: AttentionSpec, @@ -559,12 +557,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) - def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor): return MLACommonDecodeMetadata( diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index dd10b7f02730a..abe05174507ff 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -251,9 +251,6 @@ class AiterFlashAttentionMetadataBuilder( self.aot_sliding_window: Optional[tuple[int, int]] = None self.total_tokens: int = 0 - def reorder_batch(self, input_batch, scheduler_output) -> bool: - return False - def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata): self.total_tokens = 
self.model_config.max_model_len \ diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 0f041573e9d20..6defd211f4cfa 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -167,6 +167,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder support CUDA Graphs for attention. attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.NEVER + # Does this backend/builder reorder the batch? + # If not, set this to None. Otherwise set it to the query + # length that will be pulled into the front of the batch. + reorder_batch_threshold: ClassVar[Optional[int]] = None @abstractmethod def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], @@ -221,14 +225,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ) -> bool: return False - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - """ - This method can reorder the batch if desired by the backend. - :return: Has the batch been reordered (default False). 
- """ - return False - @functools.lru_cache def get_kv_cache_layout(): diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 6b2b50a57e1f8..d8f3e0d89a960 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -9,8 +9,12 @@ import torch.nn as nn from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model +from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1 from vllm.v1.worker.gpu_model_runner import GPUModelRunner +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + logger = init_logger(__name__) @@ -27,6 +31,34 @@ class CPUModelRunner(GPUModelRunner): self._postprocess_tenosrs() + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: + """ + Update the order of requests in the batch based on the attention + backend's needs. For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. + + Args: + scheduler_output: The scheduler output. + """ + # Attention free models have zero kv_cache_goups, however models + # like Mamba are also attention free but use the kv_cache for + # keeping its internal state. This is why we check the number + # of kv_cache groups instead of solely checking + # for self.model_config.is_attention_free. 
+        if len(self.kv_cache_config.kv_cache_groups) == 0: + return + + if len(self.kv_cache_config.kv_cache_groups) > 1: + raise ValueError("Multiple KVCacheGroups is not " + "currently supported with CPU model runner.") + + assert type( + self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1 + + self.attn_metadata_builders[0].reorder_batch(self.input_batch, + scheduler_output) + def _postprocess_tenosrs(self) -> None: # Note: replace device tensors with cpu tensors def replace_tensor(obj: Any, cpu_attr_name: str, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d5a5799efb47c..42cef6c5733d2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -49,7 +49,8 @@ from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, make_kv_sharing_fast_prefill_attention_metadata, - make_local_attention_virtual_batches) + make_local_attention_virtual_batches, + reorder_batch_to_split_decodes_and_prefills) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, @@ -329,6 +330,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.kv_sharing_fast_prefill_logits_indices = torch.zeros( self.max_num_tokens, dtype=torch.int32, device=self.device) + self.reorder_batch_threshold: Optional[int] = None + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ Update the order of requests in the batch based on the attention @@ -347,20 +350,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if len(self.kv_cache_config.kv_cache_groups) == 0: return - self.attn_metadata_builders[0].reorder_batch(self.input_batch, scheduler_output) - - # For models with multiple KV cache groups, the groups should agree on - 
# the same order of requests. We ensure this by only allowing the first - # group to reorder the batch and asserting that all other groups do not - # reorder the batch. - # TODO(tdoublep): make this more flexible so that any group can - # re-order the batch (not only the first). - # TODO(tdoublep): verify this during engine init instead of at runtime - for i in range(1, len(self.kv_cache_config.kv_cache_groups)): - batch_reordered = self.attn_metadata_builders[i].reorder_batch( - self.input_batch, scheduler_output) - assert not batch_reordered + if self.reorder_batch_threshold is not None: + reorder_batch_to_split_decodes_and_prefills( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) # Note: used for model runner override. def _init_device_properties(self) -> None: @@ -2654,6 +2648,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) + # Calculate reorder batch threshold (if needed) + self.calculate_reorder_batch_threshold() + if len(self.attn_backends) > 0: return @@ -2688,6 +2685,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.attn_metadata_builders.append(attn_metadata_builder) self.is_encoder_only_model = True + def calculate_reorder_batch_threshold(self) -> None: + """ + Check that if any backends reorder batches; that the reordering + is compatible (e.g., decode threshold is the same) + """ + for attn_metadata_builder_i in self.attn_metadata_builders: + # check that if any backends reorder batches; that the reordering + # is compatible (e.g., decode threshold is the same) + reorder_batch_threshold_i = ( + attn_metadata_builder_i.reorder_batch_threshold) + if reorder_batch_threshold_i is not None: + if self.reorder_batch_threshold is not None: + if reorder_batch_threshold_i != \ + self.reorder_batch_threshold: + raise ValueError( + f"Attention backend reorders 
decodes with " + f"threshold {reorder_batch_threshold_i} but other " + f"backend uses threshold " + f"{self.reorder_batch_threshold}") + else: + self.reorder_batch_threshold = reorder_batch_threshold_i + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """