[Bug] Batch invariant: Fix flash attn MLA RuntimeError: scheduler_metadata must have shape (metadata_size) (#27884)

This commit is contained in:
Wentao Ye 2025-11-04 01:05:55 -05:00 committed by GitHub
parent 380ba6816d
commit 7e4be74104
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 3 deletions

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib import contextlib
import functools
import os import os
from collections import namedtuple from collections import namedtuple
from collections.abc import Callable from collections.abc import Callable
@ -846,6 +847,7 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
return AttentionBlockSize(block_m=16, block_n=16) return AttentionBlockSize(block_m=16, block_n=16)
@functools.cache
def vllm_is_batch_invariant(): def vllm_is_batch_invariant():
env_key = "VLLM_BATCH_INVARIANT" env_key = "VLLM_BATCH_INVARIANT"
is_overridden = False is_overridden = False

View File

@ -163,6 +163,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# we only set num_splits when using cuda graphs. # we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits max_num_splits = self.max_num_splits
if vllm_is_batch_invariant():
max_num_splits = 1
scheduler_metadata = self._schedule_decode( scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(), num_reqs=seq_lens_cpu.numel(),
cu_query_lens=query_start_loc_device, cu_query_lens=query_start_loc_device,
@ -188,9 +191,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
self.scheduler_metadata[n:] = 0 self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n] scheduler_metadata = self.scheduler_metadata[:n]
if vllm_is_batch_invariant():
max_num_splits = 1
metadata = FlashAttnMLADecodeMetadata( metadata = FlashAttnMLADecodeMetadata(
block_table=block_table_tensor, block_table=block_table_tensor,
seq_lens=seq_lens_device, seq_lens=seq_lens_device,