From 7e4be741044bfead91afc418100ff9a4d804bf7f Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Tue, 4 Nov 2025 01:05:55 -0500
Subject: [PATCH] [Bug] Batch invariant: Fix flash attn MLA `RuntimeError:
 scheduler_metadata must have shape (metadata_size)` (#27884)

---
 vllm/model_executor/layers/batch_invariant.py   | 2 ++
 vllm/v1/attention/backends/mla/flashattn_mla.py | 6 +++---
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 39e77b935d3d5..0234f228d700a 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
+import functools
 import os
 from collections import namedtuple
 from collections.abc import Callable
@@ -846,6 +847,7 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
 
 
+@functools.cache
 def vllm_is_batch_invariant():
     env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index a6aac701b784b..6baf45efccb54 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -163,6 +163,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
             # we only set num_splits when using cuda graphs.
             max_num_splits = self.max_num_splits
 
+        if vllm_is_batch_invariant():
+            max_num_splits = 1
+
         scheduler_metadata = self._schedule_decode(
             num_reqs=seq_lens_cpu.numel(),
             cu_query_lens=query_start_loc_device,
@@ -188,9 +191,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
             self.scheduler_metadata[n:] = 0
             scheduler_metadata = self.scheduler_metadata[:n]
 
-        if vllm_is_batch_invariant():
-            max_num_splits = 1
-
         metadata = FlashAttnMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
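
The fix has two parts. Caching `vllm_is_batch_invariant()` with `@functools.cache` makes the `VLLM_BATCH_INVARIANT` environment check run once and return the same answer for the rest of the process, and moving the `max_num_splits = 1` override ahead of the `_schedule_decode()` call means the FA3 scheduler metadata is generated for the same split count the kernel later uses, which is what avoids the `scheduler_metadata must have shape (metadata_size)` error. Below is a minimal sketch of the caching idea only; `batch_invariant_enabled` is a hypothetical stand-in, and the real `vllm_is_batch_invariant()` has its own parsing and override logic:

    import functools
    import os

    @functools.cache
    def batch_invariant_enabled() -> bool:
        # Hypothetical helper: the environment variable is read only on the
        # first call; every later call returns the cached boolean, so the
        # flag cannot change value partway through a run.
        return os.getenv("VLLM_BATCH_INVARIANT", "0") == "1"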