[Spec Decode] Enable efficient speculative decoding with FlashInfer-MLA (#25984)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
parent 6ebaf43ee4
commit 3d1f67616d
@@ -190,7 +190,7 @@ return curr_o @ W_O
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import Generic, Optional, TypeVar, Union
+from typing import ClassVar, Generic, Optional, TypeVar, Union
 
 import torch
 from tqdm import tqdm
@@ -454,6 +454,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     understand this class
     """
 
+    # Whether the backend supports reordering the batch such that
+    # short sequences (i.e. verification for speculative decoding) are
+    # classified as decode requests.
+    # If True, this will increase `reorder_batch_threshold` (below) when
+    # speculative decoding is enabled, and set `require_uniform=True` when
+    # reordering the batch. Non-uniform decode requests will
+    # fall back to prefill in this case.
+    supports_uniform_spec_as_decode: ClassVar[bool] = False
+
+    # The threshold for reordering the batch into decode and prefill requests.
+    # If > 1, the batch will be reordered such that requests with
+    # query length <= threshold are classified as decode requests.
+    # Use `supports_uniform_spec_as_decode` (above) to set this automatically
+    # when speculative decoding is enabled.
     reorder_batch_threshold: int = 1
 
     @staticmethod
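
As an aside, the rule described by the two comment blocks above can be sketched in a few lines of standalone Python. This is illustrative only; the helper name and the exact fallback policy are made up here, not vLLM's implementation:

# Illustrative sketch of the reordering rule; not vLLM code.
def classify_requests(
    query_lens: list[int],
    reorder_batch_threshold: int,
    require_uniform: bool,
) -> tuple[list[int], list[int]]:
    """Return (decode_indices, prefill_indices) for a batch of requests."""
    # Requests with query length <= threshold are decode candidates.
    decodes = [i for i, q in enumerate(query_lens) if q <= reorder_batch_threshold]
    prefills = [i for i, q in enumerate(query_lens) if q > reorder_batch_threshold]

    if require_uniform and decodes:
        # Uniformity: keep only decode candidates sharing one query length
        # (here, the length of the first candidate); the rest fall back
        # to prefill, mirroring the comment above.
        ref_len = query_lens[decodes[0]]
        prefills += [i for i in decodes if query_lens[i] != ref_len]
        decodes = [i for i in decodes if query_lens[i] == ref_len]

    return decodes, sorted(prefills)

# Example: threshold 3 (i.e. 1 + 2 speculative tokens).
print(classify_requests([3, 3, 1, 3, 17], 3, require_uniform=True))
# -> ([0, 1, 3], [2, 4]): the length-1 request is short enough to decode,
#    but non-uniform, so it falls back to prefill with the length-17 request.
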
@@ -503,6 +517,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
         self.compilation_config = vllm_config.compilation_config
+        self.vllm_config = vllm_config
         self.device = device
 
         self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
@@ -578,6 +593,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             device=device,
         )
 
+        supports_spec_as_decode = self.supports_uniform_spec_as_decode
+        self._init_reorder_batch_threshold(
+            self.reorder_batch_threshold, supports_spec_as_decode
+        )
+
     def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
         qo_indptr = prefill.query_start_loc
 
@@ -714,7 +734,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata,
+                decode_threshold=self.reorder_batch_threshold,
+                require_uniform=self.supports_uniform_spec_as_decode,
             )
         )
 
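
For concreteness, a tiny worked example (made-up query lengths, computed by hand rather than by calling into vLLM) of the four counts that split_decodes_and_prefills returns when the decode portion is uniform:

# Hypothetical batch, already reordered so decode requests come first.
query_lens = [3, 3, 3, 17]
decode_threshold = 3  # 1 + num_speculative_tokens, e.g. two draft tokens

num_decodes = sum(q <= decode_threshold for q in query_lens)
num_prefills = len(query_lens) - num_decodes
num_decode_tokens = sum(q for q in query_lens if q <= decode_threshold)
num_prefill_tokens = sum(q for q in query_lens if q > decode_threshold)

assert (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) == (3, 1, 9, 17)
# Uniform decodes: every decode request contributes the same number of tokens.
assert num_decode_tokens % num_decodes == 0
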
@@ -22,6 +22,9 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
 
 
 class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable spec-as-decode optimization
+    supports_uniform_spec_as_decode: ClassVar[bool] = True
+
     # enable full CUDA Graph support for decode-only capture
     cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
 
@@ -111,7 +114,15 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         q = torch.cat([q_nope, q_pe], dim=-1)
 
         # trtllm API requires extra dimension q_len_per_request for MTP
-        q = q.unsqueeze(1)
+        if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
+            logger.warning_once(
+                """FlashInferMLAImpl got a query of uneven length.
+                This usually indicates an issue in batch reordering
+                or incorrect setup in dummy_run."""
+            )
+            q = q.unsqueeze(1)
+        else:
+            q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
 
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
@@ -132,6 +143,9 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
             bmm2_scale=self.bmm2_scale,
         )
 
+        # Flatten the output for consistent shape
+        o = o.view(-1, o.shape[-2], o.shape[-1])
+
         # TODO: Return LSE pending support from Flashinfer API:
         # https://github.com/flashinfer-ai/flashinfer/pull/1566
         return o, None
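
The reshapes in the two hunks above amount to the following shape bookkeeping (sizes are illustrative, q_mtp and o_flat are made-up names, and the attention call itself is elided):

import torch

num_decodes, q_len = 4, 3          # uniform: 1 new token + 2 drafts per request
num_heads, head_dim = 16, 576      # made-up head geometry
num_decode_tokens = num_decodes * q_len

q = torch.randn(num_decode_tokens, num_heads, head_dim)

# Expose q_len_per_request as its own dimension, as the trtllm API expects.
q_mtp = q.view(num_decodes, -1, q.shape[-2], q.shape[-1])
assert q_mtp.shape == (num_decodes, q_len, num_heads, head_dim)

# ... attention runs here; assume the output keeps the same layout ...
o = q_mtp

# Flatten back to one row per token for a consistent output shape.
o_flat = o.view(-1, o.shape[-2], o.shape[-1])
assert o_flat.shape == (num_decode_tokens, num_heads, head_dim)
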
@@ -275,8 +275,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
             speculative_config is not None
             and speculative_config.num_speculative_tokens is not None
         ):
-            self.reorder_batch_threshold = (
-                1 + speculative_config.num_speculative_tokens
+            self.reorder_batch_threshold = max(
+                self.reorder_batch_threshold,
+                1 + speculative_config.num_speculative_tokens,
             )
 
     @abstractmethod
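
Switching from a plain assignment to max() means a backend that already requested a larger reordering threshold keeps it; with made-up numbers:

num_speculative_tokens = 3

reorder_batch_threshold = 1  # backend default
reorder_batch_threshold = max(reorder_batch_threshold, 1 + num_speculative_tokens)
assert reorder_batch_threshold == 4  # raised to cover the draft tokens

reorder_batch_threshold = 8  # backend that already asked for a larger threshold
reorder_batch_threshold = max(reorder_batch_threshold, 1 + num_speculative_tokens)
assert reorder_batch_threshold == 8  # unchanged: max() never lowers it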