[Bug] Batch invariant: Fix flash attn MLA RuntimeError: scheduler_metadata must have shape (metadata_size) (#27884)

This commit is contained in:
Wentao Ye 2025-11-04 01:05:55 -05:00 committed by GitHub
parent 380ba6816d
commit 7e4be74104
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 3 deletions

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import functools
import os
from collections import namedtuple
from collections.abc import Callable
@ -846,6 +847,7 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
return AttentionBlockSize(block_m=16, block_n=16)
@functools.cache
def vllm_is_batch_invariant():
env_key = "VLLM_BATCH_INVARIANT"
is_overridden = False

View File

@ -163,6 +163,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
if vllm_is_batch_invariant():
max_num_splits = 1
scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(),
cu_query_lens=query_start_loc_device,
@ -188,9 +191,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
if vllm_is_batch_invariant():
max_num_splits = 1
metadata = FlashAttnMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,