[V1] Add tree drafting tests for eagle spec decoding (#22705)

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Author: Giancarlo Delfin
Date: 2025-08-13 04:11:28 -07:00 (committed via GitHub)
Parent: 3f52738dce
Commit: d94e3026de
4 changed files with 178 additions and 55 deletions
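
For orientation, a minimal sketch (in the spirit of the new tests below) of how a speculative token tree is described: each entry is a path of child indices below the root draft position, and the tests serialize the list with str() before handing it to SpeculativeConfig. The tree shape and values here are illustrative only.

# A tree whose single root-level draft has two children:
#
#       root
#        |
#       (0,)
#      /     \
#  (0, 0)   (0, 1)
spec_token_tree = [(0,), (0, 0), (0, 1)]
num_speculative_tokens = len(spec_token_tree)  # one draft token per tree node
spec_token_tree_str = str(spec_token_tree)     # "[(0,), (0, 0), (0, 1)]"
# The string form is what _create_proposer() below passes as
# SpeculativeConfig(speculative_token_tree=...).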

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from unittest import mock
import pytest
@@ -23,7 +24,11 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def _create_proposer(method: str, k: int) -> EagleProposer:
def _create_proposer(
method: str,
num_speculative_tokens: int,
speculative_token_tree: Optional[list[tuple[int, ...]]] = None,
) -> EagleProposer:
model_config = ModelConfig(model=model_dir,
runner="generate",
max_model_len=100)
@@ -31,12 +36,18 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
# Choose model directory based on method
draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir
spec_token_tree_str = None
if speculative_token_tree is not None:
assert num_speculative_tokens == len(speculative_token_tree)
spec_token_tree_str = str(speculative_token_tree)
speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
model=draft_model_dir,
method=method,
num_speculative_tokens=k,
num_speculative_tokens=num_speculative_tokens,
speculative_token_tree=spec_token_tree_str,
)
vllm_config = VllmConfig(
@@ -189,7 +200,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
target_model.lm_head = mock.MagicMock()
# Create proposer using the helper function
proposer = _create_proposer(method, k=8)
proposer = _create_proposer(method, num_speculative_tokens=8)
# Call the method under test
proposer.load_model(target_model)
@@ -226,6 +237,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")
if (attn_backend == "TREE_ATTN"):
pytest.skip("TREE_ATTN is tested separately in test_propose_tree"
"because it requires special input mocking.")
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
@@ -378,3 +393,142 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
# Verify all tokens match our expectations
assert torch.equal(result, expected_tokens)
@pytest.mark.parametrize(
"spec_token_tree",
[
[(0, )], # A single token
[(0, ), (0, 0), (0, 0, 0)], # Chain
[(0, ), (1, ), (2, )], # Parallel
[(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
(2, 1)], # Tree
])
def test_propose_tree(spec_token_tree):
# Get GPU device.
device = torch.device(current_platform.device_type)
# Setup test parameters.
batch_size = 2
seq_len_1 = 5
seq_len_2 = 3
total_tokens = seq_len_1 + seq_len_2
vocab_size = 100
seq_lens = [seq_len_1, seq_len_2]
num_speculative_tokens = len(spec_token_tree)
# Create proposer first so we can use its actual hidden_size.
proposer = _create_proposer("eagle",
num_speculative_tokens,
speculative_token_tree=spec_token_tree)
# Get the hidden_size from the proposer to ensure consistency.
hidden_size = proposer.hidden_size
# Helper to create deterministic logits that will produce specific tokens
def create_deterministic_logits(token_ids, k: int):
logits = torch.full((batch_size, vocab_size), -100.0, device=device)
for i, token_id in enumerate(token_ids):
# Assign decreasing values to the k consecutive tokens starting at token_id.
for j in range(k):
logits[i, token_id + j] = 100.0 - j
return logits
# Mock a model that returns deterministic logits.
base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)
# Skip loading the model and replace it with a mock that returns
# deterministic outputs.
model_mock = mock.MagicMock()
# Mock the model forward calls.
forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device),
torch.zeros(total_tokens, hidden_size, device=device))]
for cu_num_drafts in proposer.cu_drafts_per_level:
h_logits = torch.zeros(batch_size * cu_num_drafts,
hidden_size,
device=device)
h_states = torch.zeros(batch_size * cu_num_drafts,
hidden_size,
device=device)
forward_returns.append((h_logits, h_states))
model_mock.side_effect = forward_returns
# Mock the compute_logits calls.
cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level,
dtype=torch.int32,
device=device)
logits_returns = []
for level, num_children in enumerate(proposer.child_drafts_per_level):
token_ids = base_token_ids + cu_num_drafts_tensor[level]
level_num_drafts = cu_num_drafts_tensor[
level + 1] - cu_num_drafts_tensor[level]
level_logits = []
for i in range(level_num_drafts // num_children):
level_logits.append(
create_deterministic_logits(token_ids + i * num_children,
num_children))
logits_returns.append(torch.stack(level_logits, dim=1))
model_mock.compute_logits.side_effect = logits_returns
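# Note (illustrative, for the full 9-node tree case): cu_drafts_per_level is
# [3, 9] and child_drafts_per_level is [3, 2], so logits_returns holds one
# stacked tensor per tree level, consumed by successive compute_logits calls:
#   level 0: shape (batch_size, 1, vocab_size) = (2, 1, 100)
#   level 1: shape (batch_size, 3, vocab_size) = (2, 3, 100)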
# Assign the mock to the proposer
proposer.model = model_mock
# Assign draft attn_layer_names since load_model is not invoked
proposer.attn_layer_names = ["layer.0"]
# Get the tree attention metadata builder.
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)
# Mock runner for attention metadata building.
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
# Setup inputs for the proposer.
target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
device=device)
target_positions = torch.cat([
torch.arange(seq_len_1, device=device),
torch.arange(seq_len_2, device=device)
])
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
batch_spec = BatchSpec(
seq_lens=seq_lens,
query_lens=seq_lens,
)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)
sampling_metadata = mock.MagicMock()
# Propose draft tokens.
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
assert result.shape == (batch_size, num_speculative_tokens)
# The tokens are expected to be consecutive integers starting
# from the base token IDs.
expected_tokens = base_token_ids[:, None] + torch.arange(
num_speculative_tokens, dtype=torch.int64, device=device)
# Verify that the draft tokens match our expectations.
assert torch.equal(result, expected_tokens)
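
As a quick sanity check on the expectation above (a worked example, not part of the diff): the mocked logits make every level draft consecutive token ids offset by each request's base id, so for the full 9-node tree the final check reduces to:

import torch

base_token_ids = torch.tensor([42, 60])   # same bases the test uses
num_speculative_tokens = 9                # the 9-node tree case
expected_tokens = base_token_ids[:, None] + torch.arange(num_speculative_tokens)
# tensor([[42, 43, 44, 45, 46, 47, 48, 49, 50],
#         [60, 61, 62, 63, 64, 65, 66, 67, 68]])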

View File

@@ -39,12 +39,6 @@ def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int, attn_backend: str):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if attn_backend == "TREE_ATTN" and num_speculative_tokens > 1:
# TREE_ATTN fails the test with multi-token spec decode
# TODO: Investigate why
pytest.skip("TREE_ATTN fails the test")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN_VLLM_V1"

View File

@@ -236,9 +236,9 @@ class TreeAttentionMetadataBuilder(
# Use prefill for drafting at the root level.
self.tree_attn_bias = torch.empty(0)
else:
# Slice the tree attention bias for drafting.
query_len = common_attn_metadata.max_query_len
start, end = draft_index, draft_index + query_len
# Slice the tree attention bias for drafting. Exclude
# the root level.
start, end = 1, 1 + common_attn_metadata.max_query_len
self.tree_attn_bias = self.tree_attn_bias[start:end,
start:end].contiguous()
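
A minimal sketch of the slicing behaviour above, assuming a hypothetical 4x4 bias whose first row/column corresponds to the root level; the mask values are made up for illustration, only the indexing mirrors the diff:

import torch

NEG_INF = float("-inf")
max_query_len = 3  # draft positions per request at this step (illustrative)
# Hypothetical bias for the tree [(0,), (0, 0), (0, 1)]; 0.0 = attend, -inf = masked.
tree_attn_bias = torch.tensor([
    [0.0, NEG_INF, NEG_INF, NEG_INF],  # root
    [0.0, 0.0,     NEG_INF, NEG_INF],  # (0,)
    [0.0, 0.0,     0.0,     NEG_INF],  # (0, 0)
    [0.0, 0.0,     NEG_INF, 0.0],      # (0, 1)
])
# Drafting now always drops the root row/column instead of slicing from
# draft_index, keeping only the draft-vs-draft block.
start, end = 1, 1 + max_query_len
draft_bias = tree_attn_bias[start:end, start:end].contiguous()  # shape (3, 3)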

View File

@@ -113,13 +113,6 @@ class EagleProposer:
num_drafts_per_level[level])
self.child_drafts_per_level.append(num_drafts_per_level[level] //
num_drafts_per_level[level - 1])
# Find the first level where the tree branches off into one or more
# children.
self.first_branching_level = None
for level in range(tree_depth):
if self.cu_drafts_per_level[level] > level + 1:
self.first_branching_level = level
break
# Precompute draft position offsets in flattened tree.
self.tree_draft_pos_offsets = torch.arange(
1,
@@ -209,11 +202,10 @@ class EagleProposer:
logits = self.model.compute_logits(sample_hidden_states, None)
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if self.first_branching_level == 0:
# Branching has occurred at the root level. Draft using tree
# attention.
if isinstance(attn_metadata, TreeAttentionMetadata):
# Draft using tree attention.
draft_token_ids_list = self.propose_tree(
tree_root_level=0,
batch_size=batch_size,
logits=logits,
positions=positions,
@@ -242,11 +234,10 @@ class EagleProposer:
(TritonAttentionMetadata, AiterFlashAttentionMetadata,
FlashAttentionMetadata))
else:
# Currently, only FlashAttention and TreeAttention support
# multi-token eagle spec decode. This is because the code below
# makes assumptions about attn_metadata attributes available.
assert isinstance(attn_metadata,
(FlashAttentionMetadata, TreeAttentionMetadata))
# Currently, only FlashAttention supports multi-token eagle spec
# decode. This is because the code below makes assumptions about
# attn_metadata attributes available.
assert isinstance(attn_metadata, FlashAttentionMetadata)
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
@@ -259,7 +250,7 @@ class EagleProposer:
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for token_index in range(self.num_speculative_tokens - 1):
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
@@ -327,21 +318,6 @@ class EagleProposer:
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
if self.first_branching_level == token_index + 1:
# Branching has occurred. The remaining tokens are drafted
# using tree attention.
draft_token_ids_list += self.propose_tree(
tree_root_level=token_index + 1,
batch_size=batch_size,
logits=logits,
positions=positions,
hidden_states=hidden_states,
common_attn_metadata=common_attn_metadata,
)
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
@@ -351,7 +327,6 @@ class EagleProposer:
def propose_tree(
self,
tree_root_level: int,
batch_size: int,
# [num_tokens, vocab_size]
logits: torch.Tensor,
@@ -366,10 +341,10 @@ class EagleProposer:
assert isinstance(tree_attn_metadata_builder,
TreeAttentionMetadataBuilder)
total_num_drafts = self.cu_drafts_per_level[tree_root_level]
total_num_drafts = self.cu_drafts_per_level[0]
level_num_drafts = total_num_drafts
# Sample a draft token for each child at the tree root level.
num_children = self.child_drafts_per_level[tree_root_level]
num_children = self.child_drafts_per_level[0]
if num_children == 1:
draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
else:
@@ -393,22 +368,23 @@ class EagleProposer:
positions.view(batch_size, -1) +
self.tree_draft_pos_offsets[:batch_size, :])
tree_depth = len(self.cu_drafts_per_level)
for level in range(tree_root_level, tree_depth - 1):
for level in range(tree_depth - 1):
# Get draft positions for RoPE.
draft_positions = positions + (level + 1)
exceeds_max_model_len = (positions +
total_num_drafts) >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_draft_positions = torch.where(
draft_positions = torch.where(
exceeds_max_model_len,
0,
draft_positions,
)
).view(batch_size, -1)
if level_num_drafts > 1:
# Repeat the positions for each draft at this level.
draft_positions = clamped_draft_positions.repeat_interleave(
level_num_drafts).reshape(batch_size, -1)
draft_positions = draft_positions.repeat_interleave(
level_num_drafts, dim=1)
if num_children > 1:
# Repeat draft hidden states for each child.
@@ -425,7 +401,7 @@ class EagleProposer:
# Build new attention metadata for the next level of drafts.
# This is necessary to support tree attention.
query_len = total_num_drafts - tree_root_level
query_len = total_num_drafts
common_attn_metadata = replace(
common_attn_metadata,
query_start_loc=query_len * self.arange[:batch_size + 1],
@@ -435,7 +411,7 @@ class EagleProposer:
)
attn_metadata = tree_attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=tree_root_level + 1,
draft_index=level + 1,
)
# Apply new attention metadata to all layers.
@@ -516,7 +492,6 @@ class EagleProposer:
level_num_drafts = self.cu_drafts_per_level[level +
1] - total_num_drafts
total_num_drafts = self.cu_drafts_per_level[level + 1]
return draft_token_ids_list
def prepare_inputs(
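
For reference, a hedged sketch of how the per-level bookkeeping that propose_tree relies on (and that the new test reads via proposer.cu_drafts_per_level / child_drafts_per_level) can be derived from a token tree; this is an illustration based on the tests' tree format, not a copy of EagleProposer.__init__:

from collections import Counter

def tree_level_stats(spec_token_tree: list[tuple[int, ...]]) -> tuple[list[int], list[int]]:
    # Level of a draft = path length - 1 (level 0 holds the root's children).
    drafts_per_level = Counter(len(path) - 1 for path in spec_token_tree)
    depth = max(drafts_per_level) + 1
    cu_drafts_per_level, child_drafts_per_level = [], []
    total, prev_level_drafts = 0, 1
    for level in range(depth):
        level_drafts = drafts_per_level[level]
        total += level_drafts
        cu_drafts_per_level.append(total)  # cumulative drafts up to this level
        child_drafts_per_level.append(level_drafts // prev_level_drafts)  # children per parent
        prev_level_drafts = level_drafts
    return cu_drafts_per_level, child_drafts_per_level

# Full tree from the parametrized test:
tree = [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
print(tree_level_stats(tree))  # ([3, 9], [3, 2])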