[Spec Decode] Enable efficient speculative decoding with FlashInfer-MLA (#25984)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Benjamin Chislett 2025-10-07 16:05:59 -04:00 committed by GitHub
parent 6ebaf43ee4
commit 3d1f67616d
3 changed files with 42 additions and 5 deletions

View File

@@ -190,7 +190,7 @@ return curr_o @ W_O
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import Generic, Optional, TypeVar, Union
+from typing import ClassVar, Generic, Optional, TypeVar, Union

 import torch
 from tqdm import tqdm
@@ -454,6 +454,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     understand this class
     """

+    # Whether the backend supports reordering the batch such that
+    # short sequences (i.e. verification for speculative decoding) are
+    # classified as decode requests.
+    # If True, this will increase `reorder_batch_threshold` (below) when
+    # speculative decoding is enabled, and set `require_uniform=True`
+    # when reordering the batch. Non-uniform decode requests will
+    # fall back to prefill in this case.
+    supports_uniform_spec_as_decode: ClassVar[bool] = False
+
+    # The threshold for reordering the batch into decode and prefill requests.
+    # If > 1, the batch will be reordered such that requests with
+    # query length <= threshold are classified as decode requests.
+    # Use `supports_uniform_spec_as_decode` (above) to set this automatically
+    # when speculative decoding is enabled.
     reorder_batch_threshold: int = 1

     @staticmethod
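To make the reordering contract described in the comments above concrete, here is a minimal, self-contained sketch of threshold-based classification with the uniformity requirement. It is illustrative only; `classify_requests` is a hypothetical helper, not vLLM's actual implementation:

# Sketch of the reordering contract; `classify_requests` is a hypothetical
# helper, not part of vLLM.
def classify_requests(
    query_lens: list[int], threshold: int, require_uniform: bool
) -> tuple[list[int], list[int]]:
    # A request counts as a decode when its query length is <= threshold.
    decodes = [q for q in query_lens if q <= threshold]
    prefills = [q for q in query_lens if q > threshold]
    if require_uniform and len(set(decodes)) > 1:
        # Non-uniform decode requests fall back to the prefill path.
        prefills = decodes + prefills
        decodes = []
    return decodes, prefills

# threshold=4 (1 target token + 3 speculative tokens): the uniform
# verification requests are treated as decodes, the long request stays
# a prefill.
assert classify_requests([4, 4, 100], 4, True) == ([4, 4], [100])
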
@@ -503,6 +517,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
         self.compilation_config = vllm_config.compilation_config
+        self.vllm_config = vllm_config
         self.device = device

         self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
@@ -578,6 +593,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             device=device,
         )

+        supports_spec_as_decode = self.supports_uniform_spec_as_decode
+        self._init_reorder_batch_threshold(
+            self.reorder_batch_threshold, supports_spec_as_decode
+        )
+
     def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
         qo_indptr = prefill.query_start_loc
@@ -714,7 +734,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata,
+                decode_threshold=self.reorder_batch_threshold,
+                require_uniform=self.supports_uniform_spec_as_decode,
             )
         )
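One consequence of passing `require_uniform=True` here: every decode request contributes the same number of query tokens, so the decode token count divides evenly across decode requests. The FlashInfer-MLA decode path in the next file relies on exactly this invariant. A sketch with illustrative numbers (not actual vLLM code):

# Illustrative numbers: 8 uniform decode requests verifying 4 tokens each
# (1 target token + 3 drafts).
num_decodes, num_decode_tokens = 8, 32
assert num_decode_tokens % num_decodes == 0
q_len_per_request = num_decode_tokens // num_decodes  # == 4
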

View File

@@ -22,6 +22,9 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

 class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable spec-as-decode optimization
+    supports_uniform_spec_as_decode: ClassVar[bool] = True
+
     # enable full CUDA Graph support for decode-only capture
     cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
@@ -111,7 +114,15 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         q = torch.cat([q_nope, q_pe], dim=-1)

         # trtllm API requires extra dimension q_len_per_request for MTP
-        q = q.unsqueeze(1)
+        if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
+            logger.warning_once(
+                "FlashInferMLAImpl got a query of uneven length. "
+                "This usually indicates an issue in batch reordering "
+                "or incorrect setup in dummy_run."
+            )
+            q = q.unsqueeze(1)
+        else:
+            q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])

         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
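The reshape above turns the flat, token-major query into the per-request layout the kernel expects. A standalone sketch of the two layouts, where all shapes are assumptions chosen for illustration (e.g. head_dim 576 = 512 kv_lora_rank + 64 rope):

import torch

# Assumed illustrative shapes: 8 decode requests x 4 query tokens each,
# 16 heads, head_dim 576.
num_decodes, q_len, num_heads, head_dim = 8, 4, 16, 576
q = torch.randn(num_decodes * q_len, num_heads, head_dim)  # flat, token-major

# Uniform case: fold tokens into an explicit q_len_per_request dimension.
q_uniform = q.view(num_decodes, -1, q.shape[-2], q.shape[-1])
assert q_uniform.shape == (num_decodes, q_len, num_heads, head_dim)

# Fallback for uneven lengths: every token becomes its own "request"
# with q_len_per_request == 1.
q_fallback = q.unsqueeze(1)
assert q_fallback.shape == (num_decodes * q_len, 1, num_heads, head_dim)
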
@@ -132,6 +143,9 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
             bmm2_scale=self.bmm2_scale,
         )

+        # Flatten the output for consistent shape
+        o = o.view(-1, o.shape[-2], o.shape[-1])
+
         # TODO: Return LSE pending support from Flashinfer API:
         # https://github.com/flashinfer-ai/flashinfer/pull/1566
         return o, None
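The flatten above is the inverse of the query reshape: the kernel returns output per request, while callers expect the flat token-major layout. A small sketch under the same assumed shapes as before (head_dim_v 512 is also an assumption):

import torch

# From [num_decodes, q_len_per_request, num_heads, head_dim_v]
# back to the flat [num_tokens, num_heads, head_dim_v] layout.
o = torch.randn(8, 4, 16, 512)
o_flat = o.view(-1, o.shape[-2], o.shape[-1])
assert o_flat.shape == (8 * 4, 16, 512)
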

View File

@@ -275,8 +275,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
             speculative_config is not None
             and speculative_config.num_speculative_tokens is not None
         ):
-            self.reorder_batch_threshold = (
-                1 + speculative_config.num_speculative_tokens
+            self.reorder_batch_threshold = max(
+                self.reorder_batch_threshold,
+                1 + speculative_config.num_speculative_tokens,
             )

     @abstractmethod
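A worked example of the change above, with illustrative numbers: using max() lets a backend keep a larger threshold it may have set itself, instead of being overwritten.

# With 3 speculative tokens, each verification step scores 1 + 3 = 4 query
# tokens per request, so the decode threshold must be at least 4.
num_speculative_tokens = 3
reorder_batch_threshold = 1  # default
reorder_batch_threshold = max(
    reorder_batch_threshold, 1 + num_speculative_tokens
)
assert reorder_batch_threshold == 4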