diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index f7ec18f5e9f6..3fb00f5917ea 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -190,7 +190,7 @@ return curr_o @ W_O
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import Generic, Optional, TypeVar, Union
+from typing import ClassVar, Generic, Optional, TypeVar, Union
 
 import torch
 from tqdm import tqdm
@@ -454,6 +454,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     understand this class
     """
 
+    # Whether the backend supports reordering the batch such that
+    # short sequences (i.e. verification for speculative decoding) are
+    # classified as decode requests.
+    # If True, this will increase `reorder_batch_threshold` (below) when
+    # speculative decoding is enabled, and set `require_uniform=True`
+    # when reordering the batch. Non-uniform decode requests will
+    # fall back to prefill in this case.
+    supports_uniform_spec_as_decode: ClassVar[bool] = False
+
+    # The threshold for reordering the batch into decode and prefill requests.
+    # If > 1, the batch will be reordered such that requests with
+    # query length <= threshold are classified as decode requests.
+    # Use `supports_uniform_spec_as_decode` (above) to set this automatically
+    # when speculative decoding is enabled.
     reorder_batch_threshold: int = 1
 
     @staticmethod
@@ -503,6 +517,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
         self.compilation_config = vllm_config.compilation_config
+        self.vllm_config = vllm_config
         self.device = device
 
         self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
@@ -578,6 +593,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 device=device,
             )
 
+        supports_spec_as_decode = self.supports_uniform_spec_as_decode
+        self._init_reorder_batch_threshold(
+            self.reorder_batch_threshold, supports_spec_as_decode
+        )
+
     def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
         qo_indptr = prefill.query_start_loc
@@ -714,7 +734,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata,
+                decode_threshold=self.reorder_batch_threshold,
+                require_uniform=self.supports_uniform_spec_as_decode,
             )
         )
 
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 13552edab87b..206f96ea366a 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -22,6 +22,9 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
 
 
 class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable spec-as-decode optimization
+    supports_uniform_spec_as_decode: ClassVar[bool] = True
+
     # enable full CUDA Graph support for decode-only capture
     cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
 
@@ -111,7 +114,15 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         q = torch.cat([q_nope, q_pe], dim=-1)
 
         # trtllm API requires extra dimension q_len_per_request for MTP
-        q = q.unsqueeze(1)
+        if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
+            logger.warning_once(
+                "FlashInferMLAImpl got a query of uneven length. "
+                "This usually indicates an issue in batch reordering "
+                "or incorrect setup in dummy_run."
+            )
+            q = q.unsqueeze(1)
+        else:
+            q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
 
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
@@ -132,6 +143,9 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
             bmm2_scale=self.bmm2_scale,
         )
 
+        # Flatten the output for consistent shape
+        o = o.view(-1, o.shape[-2], o.shape[-1])
+
         # TODO: Return LSE pending support from Flashinfer API:
         # https://github.com/flashinfer-ai/flashinfer/pull/1566
         return o, None
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 003c7253e553..add2c3cb8d59 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -275,8 +275,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
             speculative_config is not None
             and speculative_config.num_speculative_tokens is not None
         ):
-            self.reorder_batch_threshold = (
-                1 + speculative_config.num_speculative_tokens
+            self.reorder_batch_threshold = max(
+                self.reorder_batch_threshold,
+                1 + speculative_config.num_speculative_tokens,
             )
 
     @abstractmethod