[Spec Decode] Enable FlashInfer Spec Decoding (#25196)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Co-authored-by: lhsjohn <huashuoli@tencent.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Authored by Benjamin Chislett on 2025-09-23 22:29:58 -04:00, committed by yewentao256
parent 0e54bbe108
commit 177c37e960
12 changed files with 250 additions and 49 deletions

View File

@ -9,7 +9,8 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata)
split_attn_metadata,
split_decodes_and_prefills)
from vllm.v1.worker.ubatch_utils import create_ubatch_slices
@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
def apply_split_decodes_and_prefills(query_lens: list[int],
decode_threshold: int,
require_uniform: bool):
"""Helper function to apply split_decodes_and_prefills and return
the results."""
device = torch.device("cpu")
seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
common_metadata = create_common_attn_metadata(BatchSpec(
seq_lens=seq_lens, query_lens=query_lens),
block_size=16,
device=device)
return split_decodes_and_prefills(common_metadata,
decode_threshold=decode_threshold,
require_uniform=require_uniform)
def test_split_decodes_and_prefills_nonuniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, False))
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
assert num_prefill_tokens == 0
def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
query_lens = [1, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
assert num_decodes == 7
assert num_prefills == 0
assert num_decode_tokens == sum(query_lens)
assert num_prefill_tokens == 0
def test_split_decodes_and_prefills_nonuniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
assert num_prefill_tokens == sum(query_lens)
def test_split_decodes_and_prefills_nonuniform_mixed_batch():
query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, False))
assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4
assert num_prefills == 4 # 5, 6, 7, 8 are all > 4
assert num_decode_tokens == 10 # 2 + 1 + 3 + 4
assert num_prefill_tokens == 26 # 5 + 6 + 7 + 8
def test_split_decodes_and_prefills_uniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, True))
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
assert num_prefill_tokens == 0
def test_split_decodes_and_prefills_uniform_all_short_decodes():
query_lens = [2, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
assert num_decodes == 2
assert num_prefills == 5
assert num_decode_tokens == 4
assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2)
def test_split_decodes_and_prefills_uniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
assert num_prefill_tokens == sum(query_lens)
def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform
assert num_prefills == 5  # 4 breaks uniformity; 5, 6, 7, 8 exceed the threshold
assert num_decode_tokens == 6 # 2 + 2 + 2
assert num_prefill_tokens == 30 # 4 + 5 + 6 + 7 + 8
def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
assert num_decodes == 1  # only the first request (query length 2) forms the uniform decode prefix
assert num_prefills == 7  # the uniform decode prefix ends at the second request
assert num_decode_tokens == 2  # tokens of that single decode request
assert num_prefill_tokens == (sum(query_lens) - 2)  # all remaining tokens
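# A minimal pure-Python restatement of the splitting rule exercised above
# (an illustrative sketch only; the real logic is split_decodes_and_prefills):
# take the longest prefix of requests whose query length fits under the
# decode threshold and, when require_uniform is set, also equals the first
# request's query length.
def _reference_split(query_lens: list[int], decode_threshold: int,
                     require_uniform: bool) -> tuple[int, int, int, int]:
    num_decodes = 0
    for i, q_len in enumerate(query_lens):
        if q_len > decode_threshold:
            break
        if require_uniform and q_len != query_lens[0]:
            break
        num_decodes = i + 1
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return (num_decodes, len(query_lens) - num_decodes, num_decode_tokens,
            num_prefill_tokens)
def test_reference_split_sketch_matches_expected_values():
    assert _reference_split([2, 1, 3, 4, 5, 6, 7, 8], 4, False) == (4, 4, 10, 26)
    assert _reference_split([2, 2, 1, 3, 2, 1, 2], 3, True) == (2, 5, 4, 9)
    assert _reference_split([2, 1, 2, 4, 5, 6, 7, 8], 4, True) == (1, 7, 2, 33)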
@pytest.mark.parametrize(
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
[

View File

@ -181,6 +181,12 @@ def force_use_trtllm_attention() -> Optional[bool]:
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
"""Check if the current configuration supports TRTLLM attention."""
has_trtllm = supports_trtllm_attention()
return has_trtllm and (num_qo_heads % num_kv_heads == 0)
def use_trtllm_attention(
num_qo_heads: int,
num_kv_heads: int,
@ -188,7 +194,9 @@ def use_trtllm_attention(
max_seq_len: int,
kv_cache_dtype: str,
q_dtype: torch.dtype,
is_prefill: bool,
has_sinks: bool = False,
has_spec: bool = False,
) -> bool:
"""Return ``True`` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention()
@ -214,6 +222,12 @@ def use_trtllm_attention(
)
return False
if has_spec and not is_prefill:
# Speculative decoding requires TRTLLM attention for decodes
logger.info_once(
"Using TRTLLM attention (enabled for speculative decoding).")
return True
# Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype():
if has_sinks:
@ -391,6 +405,7 @@ __all__ = [
"has_flashinfer_cutlass_fused_moe",
"has_nvidia_artifactory",
"supports_trtllm_attention",
"can_use_trtllm_attention",
"use_trtllm_attention",
"flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm",

View File

@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
from vllm.utils.flashinfer import (can_use_trtllm_attention,
flashinfer_disable_q_quantization,
supports_trtllm_attention,
use_trtllm_attention)
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
@ -223,6 +224,7 @@ class FlashInferMetadata:
# For flashinfer trtllm batch decode
max_q_len: int
max_q_len_prefill: int
max_seq_len: int
seq_lens: torch.Tensor
block_table_tensor: torch.Tensor
@ -250,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@ -302,6 +304,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
else:
self.q_data_type = self.model_config.dtype
supports_spec_as_decode = \
can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
self._init_reorder_batch_threshold(1, supports_spec_as_decode)
self._cascade_wrapper = None # Wrapper for cascade attention
# Global hyperparameters shared by all attention layers
@ -416,7 +422,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
decode_threshold=self.reorder_batch_threshold,
require_uniform=True)
page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len
@ -491,20 +498,25 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_last_page_len_np,
)
uses_spec_reorder = self.reorder_batch_threshold > 1
prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads,
num_prefill_tokens,
max_seq_len,
self.cache_dtype,
self.q_data_type,
has_sinks=self.has_sinks)
is_prefill=True,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder)
decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads,
num_decode_tokens,
max_seq_len,
self.cache_dtype,
self.q_data_type,
has_sinks=self.has_sinks)
is_prefill=False,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder)
if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
raise NotImplementedError(
"FlashInfer backend currently does not support attention "
@ -521,6 +533,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
q_data_type=self.q_data_type,
slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len,
max_q_len_prefill=max_q_len,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,
@ -577,6 +590,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
prefill_start]
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
# Recompute max_q_len for the slice of requests used for prefills.
# It can differ from max_q_len when the batch is non-uniform and some
# short decode-like requests are offloaded to the prefill pathway.
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
attn_metadata.max_q_len_prefill = \
int(query_lens_prefill.max().item())
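# Example: query_lens [3, 3, 1, 2] with reorder_batch_threshold == 3 and
# require_uniform splits into decodes [3, 3] and prefills [1, 2], so
# max_q_len == 3 while max_q_len_prefill == 2.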
if not attn_metadata.prefill_use_trtllm:
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
@ -607,7 +629,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens = (
self.vllm_config.pad_for_cudagraph(num_decodes))
self.vllm_config.pad_for_cudagraph(num_decode_tokens))
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing
@ -621,7 +643,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_decodes:num_input_tokens].fill_(1)
else:
num_input_tokens = num_decodes
num_input_tokens = num_decode_tokens
attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph)
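# With spec decode each decode request carries multiple query tokens, so the
# CUDA graph is sized by token count: e.g. 4 decode requests with 3 query
# tokens each pad to at least 12 tokens, not 4.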
@ -842,6 +864,9 @@ class FlashInferImpl(AttentionImpl):
output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
return output
# When using spec decoding, num_decodes can be < num_decode_tokens
# because some decode requests may have more than one query token.
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens
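# Example: with 2 speculative tokens per request, 4 decode requests carry
# 3 query tokens each, so num_decodes == 4 while num_decode_tokens == 12.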
@ -874,8 +899,8 @@ class FlashInferImpl(AttentionImpl):
prefill_query = prefill_query.contiguous()
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.block_table_tensor[
num_decode_tokens:]
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
num_decodes:]
seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]
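# block_table_tensor and seq_lens hold one row per request, so the prefill
# slice starts at num_decodes (a request count), not num_decode_tokens.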
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
@ -919,7 +944,7 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer=workspace_buffer,
block_tables=mock_block_table,
seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len,
max_q_len=attn_metadata.max_q_len_prefill,
max_kv_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
@ -976,6 +1001,14 @@ class FlashInferImpl(AttentionImpl):
assert self.o_sf_scale is None
out = output[:num_decode_tokens]
if num_decode_tokens % attn_metadata.num_decodes != 0:
# This gets triggered when the dummy_run forces
# attention to be initialized with q_len = 0
q_len_per_req = 1
else:
q_len_per_req = \
num_decode_tokens // attn_metadata.num_decodes
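# Example: 12 decode tokens across 4 requests give q_len_per_req == 3; the
# fallback above covers the dummy_run case noted in the comment, where the
# token count is not a multiple of num_decodes.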
trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache_permute,
@ -989,7 +1022,7 @@ class FlashInferImpl(AttentionImpl):
sinks=self.sinks,
o_sf_scale=self.o_sf_scale,
out=out,
)
q_len_per_req=q_len_per_req)
return output_padded

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional
import torch
@ -62,7 +62,7 @@ class GDNAttentionMetadataBuilder(
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@ -76,7 +76,7 @@ class GDNAttentionMetadataBuilder(
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc]
self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
@ -35,7 +34,7 @@ class LinearAttentionMetadata:
class LinearAttentionMetadataBuilder(
AttentionMetadataBuilder[LinearAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):

View File

@ -16,7 +16,7 @@ M = TypeVar("M")
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

View File

@ -190,7 +190,7 @@ return curr_o @ W_O
import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import ClassVar, Generic, Optional, TypeVar, Union
from typing import Generic, Optional, TypeVar, Union
import torch
from tqdm import tqdm
@ -434,7 +434,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1
@staticmethod
def determine_chunked_prefill_workspace_size(

View File

@ -64,7 +64,7 @@ class FlashAttnMLAMetadataBuilder(
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: ClassVar[int] = 512
reorder_batch_threshold: int = 512
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@ -99,7 +99,7 @@ class FlashAttnMLAMetadataBuilder(
# TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled.
self.__class__.reorder_batch_threshold = 1 \
self.reorder_batch_threshold = 1 \
if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
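# reorder_batch_threshold is now an instance attribute (no longer a
# ClassVar), so this DCP special case overrides it on the builder instance
# rather than mutating the class.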
def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional
import torch
@ -41,7 +41,7 @@ class ShortConvAttentionMetadata:
class ShortConvAttentionMetadataBuilder(
AttentionMetadataBuilder[ShortConvAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):

View File

@ -236,7 +236,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[Optional[int]] = None
reorder_batch_threshold: Optional[int] = None
@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
@ -246,6 +246,22 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
self.vllm_config = vllm_config
self.device = device
def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int = 1,
supports_spec_as_decode: bool = False) -> None:
self.reorder_batch_threshold = reorder_batch_threshold
if self.reorder_batch_threshold is not None \
and supports_spec_as_decode:
# If the backend supports spec-as-decode kernels, then we can set
# the reorder_batch_threshold based on the number of speculative
# tokens from the config.
speculative_config = self.vllm_config.speculative_config
if (speculative_config is not None
and speculative_config.num_speculative_tokens is not None):
self.reorder_batch_threshold = \
1 + speculative_config.num_speculative_tokens
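# For example, with num_speculative_tokens == 3 the threshold becomes
# 1 + 3 == 4, so a request verifying one target token plus three draft
# tokens is still routed down the decode path after batch reordering.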
@abstractmethod
def build(self,
common_prefix_len: int,
@ -703,9 +719,9 @@ def subclass_attention_backend(
def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
require_uniform: bool = False) -> tuple[int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.
@ -714,6 +730,9 @@ def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
decode_threshold: The maximum query length to be considered a decode.
require_uniform: If True, requires that all decode requests have the
same query length. When set, some queries may be considered prefills
even if they are <= decode_threshold, in order to ensure uniformity.
Returns:
num_decodes: The number of decode requests.
@ -726,11 +745,20 @@ def split_decodes_and_prefills(
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
if max_query_len <= decode_threshold:
if max_query_len <= decode_threshold and \
(not require_uniform or decode_threshold <= 1):
return num_reqs, 0, num_tokens, 0
query_lens = query_start_loc[1:] - query_start_loc[:-1]
is_prefill = query_lens > decode_threshold
if query_lens[0].item() > decode_threshold:
# the first request is not a decode, so there are no decode requests
return 0, num_reqs, 0, num_tokens
if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
@ -806,6 +834,38 @@ def reorder_batch_to_split_decodes_and_prefills(
return modified_batch
def reshape_query_for_spec_decode(query: torch.Tensor,
batch_size: int) -> torch.Tensor:
"""
Reshapes the query tensor for the specified batch size, so that
it has shape (batch_size, seq_len, num_heads, head_dim).
"""
assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
total_tokens = query.shape[0]
num_heads = query.shape[1]
head_dim = query.shape[2]
assert total_tokens % batch_size == 0, (
f"{total_tokens=} is not divisible by {batch_size=}")
seq_len = total_tokens // batch_size
return query.view(batch_size, seq_len, num_heads, head_dim)
def reshape_attn_output_for_spec_decode(
attn_output: torch.Tensor) -> torch.Tensor:
"""
Reshapes the attention output tensor, so that
the batch_size and seq_len dimensions are combined.
"""
if attn_output.dim() == 3:
# Already in the correct shape
return attn_output
assert attn_output.dim() == 4, \
f"attn_output must be 4D, got {attn_output.dim()}D"
total_tokens = attn_output.shape[0] * attn_output.shape[1]
return attn_output.view(total_tokens, attn_output.shape[2],
attn_output.shape[3])
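# Usage sketch for the two helpers above; the shapes are illustrative
# assumptions (4 decode requests, 3 query tokens each, 8 heads, head_dim 64).
def _demo_spec_decode_reshape() -> None:
    packed_query = torch.randn(4 * 3, 8, 64)
    query_4d = reshape_query_for_spec_decode(packed_query, batch_size=4)
    assert query_4d.shape == (4, 3, 8, 64)
    # Pretend the kernel returned output in the same (batch, seq, heads, dim)
    # layout; flatten it back to (tokens, heads, dim) for downstream use.
    attn_output_4d = torch.randn(4, 3, 8, 64)
    attn_output_2d = reshape_attn_output_for_spec_decode(attn_output_4d)
    assert attn_output_2d.shape == (4 * 3, 8, 64)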
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
('logits_indices_padded', Optional[torch.Tensor], None),
('num_logits_indices', int, 0),

View File

@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional
from typing import TYPE_CHECKING, Optional
import torch
@ -197,7 +197,7 @@ class XFormersAttentionMetadata:
class XFormersAttentionMetadataBuilder(
AttentionMetadataBuilder[XFormersAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1
def __init__(
self,

View File

@ -3,7 +3,7 @@
import ast
from dataclasses import replace
from importlib.util import find_spec
from typing import Optional, Protocol
from typing import Optional
import numpy as np
import torch
@ -37,17 +37,6 @@ logger = init_logger(__name__)
PADDING_SLOT_ID = -1
class EagleAttentionMetadata(Protocol):
# Required attributes
num_actual_tokens: int
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
class EagleProposer:
def __init__(
@ -120,7 +109,7 @@ class EagleProposer:
with_numpy=True)
# Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple[type, ...]
self.allowed_attn_types: Optional[tuple] = None
if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
@ -129,9 +118,6 @@ class EagleProposer:
AiterFlashAttentionMetadata)
rocm_types.append(AiterFlashAttentionMetadata)
self.allowed_attn_types = tuple(rocm_types)
else:
self.allowed_attn_types = (FlashAttentionMetadata,
TreeAttentionMetadata)
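# On non-ROCm platforms allowed_attn_types now stays None, meaning no
# restriction: any backend's metadata (e.g. FlashInfer's) is accepted for
# speculative decoding with num_speculative_tokens > 1.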
# Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree
@ -266,7 +252,8 @@ class EagleProposer:
draft_token_ids = logits.argmax(dim=-1)
if not isinstance(attn_metadata, self.allowed_attn_types):
if self.allowed_attn_types is not None and \
not isinstance(attn_metadata, self.allowed_attn_types):
raise ValueError(
f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: "