mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 10:23:04 +08:00
[Bug] Batch invariant: Fix flash attn MLA RuntimeError: scheduler_metadata must have shape (metadata_size) (#27884)
This commit is contained in:
parent
380ba6816d
commit
7e4be74104
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import functools
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable
|
||||
@ -846,6 +847,7 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
|
||||
return AttentionBlockSize(block_m=16, block_n=16)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def vllm_is_batch_invariant():
|
||||
env_key = "VLLM_BATCH_INVARIANT"
|
||||
is_overridden = False
|
||||
|
||||
@ -163,6 +163,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
max_num_splits = 1
|
||||
|
||||
scheduler_metadata = self._schedule_decode(
|
||||
num_reqs=seq_lens_cpu.numel(),
|
||||
cu_query_lens=query_start_loc_device,
|
||||
@ -188,9 +191,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
self.scheduler_metadata[n:] = 0
|
||||
scheduler_metadata = self.scheduler_metadata[:n]
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
max_num_splits = 1
|
||||
|
||||
metadata = FlashAttnMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user