mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-08 16:43:37 +08:00
[Bug] Batch invariant: Fix flash attn MLA RuntimeError: scheduler_metadata must have shape (metadata_size) (#27884)
This commit is contained in:
parent
380ba6816d
commit
7e4be74104
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
@ -846,6 +847,7 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
|
|||||||
return AttentionBlockSize(block_m=16, block_n=16)
|
return AttentionBlockSize(block_m=16, block_n=16)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
def vllm_is_batch_invariant():
|
def vllm_is_batch_invariant():
|
||||||
env_key = "VLLM_BATCH_INVARIANT"
|
env_key = "VLLM_BATCH_INVARIANT"
|
||||||
is_overridden = False
|
is_overridden = False
|
||||||
|
|||||||
@ -163,6 +163,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
|||||||
# we only set num_splits when using cuda graphs.
|
# we only set num_splits when using cuda graphs.
|
||||||
max_num_splits = self.max_num_splits
|
max_num_splits = self.max_num_splits
|
||||||
|
|
||||||
|
if vllm_is_batch_invariant():
|
||||||
|
max_num_splits = 1
|
||||||
|
|
||||||
scheduler_metadata = self._schedule_decode(
|
scheduler_metadata = self._schedule_decode(
|
||||||
num_reqs=seq_lens_cpu.numel(),
|
num_reqs=seq_lens_cpu.numel(),
|
||||||
cu_query_lens=query_start_loc_device,
|
cu_query_lens=query_start_loc_device,
|
||||||
@ -188,9 +191,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
|||||||
self.scheduler_metadata[n:] = 0
|
self.scheduler_metadata[n:] = 0
|
||||||
scheduler_metadata = self.scheduler_metadata[:n]
|
scheduler_metadata = self.scheduler_metadata[:n]
|
||||||
|
|
||||||
if vllm_is_batch_invariant():
|
|
||||||
max_num_splits = 1
|
|
||||||
|
|
||||||
metadata = FlashAttnMLADecodeMetadata(
|
metadata = FlashAttnMLADecodeMetadata(
|
||||||
block_table=block_table_tensor,
|
block_table=block_table_tensor,
|
||||||
seq_lens=seq_lens_device,
|
seq_lens=seq_lens_device,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user