mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:55:01 +08:00
[V1] Add tree drafting tests for eagle spec decoding (#22705)
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
This commit is contained in:
parent
3f52738dce
commit
d94e3026de
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@ -23,7 +24,11 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def _create_proposer(method: str, k: int) -> EagleProposer:
|
||||
def _create_proposer(
|
||||
method: str,
|
||||
num_speculative_tokens: int,
|
||||
speculative_token_tree: Optional[list[tuple[int]]] = None,
|
||||
) -> EagleProposer:
|
||||
model_config = ModelConfig(model=model_dir,
|
||||
runner="generate",
|
||||
max_model_len=100)
|
||||
@ -31,12 +36,18 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
|
||||
# Choose model directory based on method
|
||||
draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir
|
||||
|
||||
spec_token_tree_str = None
|
||||
if speculative_token_tree is not None:
|
||||
assert num_speculative_tokens == len(speculative_token_tree)
|
||||
spec_token_tree_str = str(speculative_token_tree)
|
||||
|
||||
speculative_config = SpeculativeConfig(
|
||||
target_model_config=model_config,
|
||||
target_parallel_config=ParallelConfig(),
|
||||
model=draft_model_dir,
|
||||
method=method,
|
||||
num_speculative_tokens=k,
|
||||
num_speculative_tokens=num_speculative_tokens,
|
||||
speculative_token_tree=spec_token_tree_str,
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
@ -189,7 +200,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
target_model.lm_head = mock.MagicMock()
|
||||
|
||||
# Create proposer using the helper function
|
||||
proposer = _create_proposer(method, k=8)
|
||||
proposer = _create_proposer(method, num_speculative_tokens=8)
|
||||
|
||||
# Call the method under test
|
||||
proposer.load_model(target_model)
|
||||
@ -226,6 +237,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||
"multi-token eagle spec decode on current platform")
|
||||
|
||||
if (attn_backend == "TREE_ATTN"):
|
||||
pytest.skip("TREE_ATTN is tested separately in test_propose_tree"
|
||||
"because it requires special input mocking.")
|
||||
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
@ -378,3 +393,142 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
|
||||
# Verify all tokens match our expectations
|
||||
assert torch.equal(result, expected_tokens)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec_token_tree",
|
||||
[
|
||||
[(0, )], # A single token
|
||||
[(0, ), (0, 0), (0, 0, 0)], # Chain
|
||||
[(0, ), (1, ), (2, )], # Parallel
|
||||
[(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
|
||||
(2, 1)], # Tree
|
||||
])
|
||||
def test_propose_tree(spec_token_tree):
|
||||
# Get GPU device.
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
# Setup test parameters.
|
||||
batch_size = 2
|
||||
seq_len_1 = 5
|
||||
seq_len_2 = 3
|
||||
total_tokens = seq_len_1 + seq_len_2
|
||||
vocab_size = 100
|
||||
seq_lens = [seq_len_1, seq_len_2]
|
||||
num_speculative_tokens = len(spec_token_tree)
|
||||
|
||||
# Create proposer first so we can use its actual hidden_size.
|
||||
proposer = _create_proposer("eagle",
|
||||
num_speculative_tokens,
|
||||
speculative_token_tree=spec_token_tree)
|
||||
# Get the hidden_size from the proposer to ensure consistency.
|
||||
hidden_size = proposer.hidden_size
|
||||
|
||||
# Helper to create deterministic logits that will produce specific tokens
|
||||
def create_deterministic_logits(token_ids, k: int):
|
||||
logits = torch.full((batch_size, vocab_size), -100.0, device=device)
|
||||
for i, token_id in enumerate(token_ids):
|
||||
# Assign decreasing values to the k, consecutive, tokens.
|
||||
for j in range(k):
|
||||
logits[i, token_id + j] = 100.0 - j
|
||||
return logits
|
||||
|
||||
# Mock a model that returns deterministic logits.
|
||||
base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)
|
||||
|
||||
# Skip loading the model and replace it with a mock that returns
|
||||
# deterministic outputs.
|
||||
model_mock = mock.MagicMock()
|
||||
|
||||
# Mock the model forward calls.
|
||||
forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device),
|
||||
torch.zeros(total_tokens, hidden_size, device=device))]
|
||||
for cu_num_drafts in proposer.cu_drafts_per_level:
|
||||
h_logits = torch.zeros(batch_size * cu_num_drafts,
|
||||
hidden_size,
|
||||
device=device)
|
||||
h_states = torch.zeros(batch_size * cu_num_drafts,
|
||||
hidden_size,
|
||||
device=device)
|
||||
forward_returns.append((h_logits, h_states))
|
||||
model_mock.side_effect = forward_returns
|
||||
|
||||
# Mock the compute_logits calls.
|
||||
cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
logits_returns = []
|
||||
for level, num_children in enumerate(proposer.child_drafts_per_level):
|
||||
token_ids = base_token_ids + cu_num_drafts_tensor[level]
|
||||
level_num_drafts = cu_num_drafts_tensor[
|
||||
level + 1] - cu_num_drafts_tensor[level]
|
||||
level_logits = []
|
||||
for i in range(level_num_drafts // num_children):
|
||||
level_logits.append(
|
||||
create_deterministic_logits(token_ids + i * num_children,
|
||||
num_children))
|
||||
logits_returns.append(torch.stack(level_logits, dim=1))
|
||||
model_mock.compute_logits.side_effect = logits_returns
|
||||
|
||||
# Assign the mock to the proposer
|
||||
proposer.model = model_mock
|
||||
|
||||
# Assign draft attn_layer_names since load_model is not invoked
|
||||
proposer.attn_layer_names = ["layer.0"]
|
||||
|
||||
# Get the tree attention metadata builder.
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
vllm_config=proposer.vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Mock runner for attention metadata building.
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_groups.append([mock.MagicMock()])
|
||||
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
|
||||
|
||||
# Setup inputs for the proposer.
|
||||
target_token_ids = torch.randint(0,
|
||||
vocab_size, (total_tokens, ),
|
||||
device=device)
|
||||
target_positions = torch.cat([
|
||||
torch.arange(seq_len_1, device=device),
|
||||
torch.arange(seq_len_2, device=device)
|
||||
])
|
||||
target_hidden_states = torch.randn(total_tokens,
|
||||
hidden_size,
|
||||
device=device)
|
||||
next_token_ids = torch.randint(0,
|
||||
vocab_size, (batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=seq_lens,
|
||||
query_lens=seq_lens,
|
||||
)
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
# Propose draft tokens.
|
||||
result = proposer.propose(target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
next_token_ids=next_token_ids,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata)
|
||||
assert result.shape == (batch_size, num_speculative_tokens)
|
||||
|
||||
# The tokens are expected to be consecutive integers starting
|
||||
# from the base token IDs.
|
||||
expected_tokens = base_token_ids[:, None] + torch.arange(
|
||||
num_speculative_tokens, dtype=torch.int64, device=device)
|
||||
|
||||
# Verify that the draft tokens match our expectations.
|
||||
assert torch.equal(result, expected_tokens)
|
||||
|
||||
@ -39,12 +39,6 @@ def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
|
||||
num_speculative_tokens: int, attn_backend: str):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
if attn_backend == "TREE_ATTN" and num_speculative_tokens > 1:
|
||||
# TREE_ATTN fails the test with multi-token spec decode
|
||||
# TODO: Investigate why
|
||||
pytest.skip("TREE_ATTN fails the test")
|
||||
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||
|
||||
@ -236,9 +236,9 @@ class TreeAttentionMetadataBuilder(
|
||||
# Use prefill for drafting at the root level.
|
||||
self.tree_attn_bias = torch.empty(0)
|
||||
else:
|
||||
# Slice the tree attention bias for drafting.
|
||||
query_len = common_attn_metadata.max_query_len
|
||||
start, end = draft_index, draft_index + query_len
|
||||
# Slice the tree attention bias for drafting. Exclude
|
||||
# the root level.
|
||||
start, end = 1, 1 + common_attn_metadata.max_query_len
|
||||
self.tree_attn_bias = self.tree_attn_bias[start:end,
|
||||
start:end].contiguous()
|
||||
|
||||
|
||||
@ -113,13 +113,6 @@ class EagleProposer:
|
||||
num_drafts_per_level[level])
|
||||
self.child_drafts_per_level.append(num_drafts_per_level[level] //
|
||||
num_drafts_per_level[level - 1])
|
||||
# Find the first level where the tree branches off into one or more
|
||||
# children.
|
||||
self.first_branching_level = None
|
||||
for level in range(tree_depth):
|
||||
if self.cu_drafts_per_level[level] > level + 1:
|
||||
self.first_branching_level = level
|
||||
break
|
||||
# Precompute draft position offsets in flattened tree.
|
||||
self.tree_draft_pos_offsets = torch.arange(
|
||||
1,
|
||||
@ -209,11 +202,10 @@ class EagleProposer:
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
if self.first_branching_level == 0:
|
||||
# Branching has occurred at the root level. Draft using tree
|
||||
# attention.
|
||||
|
||||
if isinstance(attn_metadata, TreeAttentionMetadata):
|
||||
# Draft using tree attention.
|
||||
draft_token_ids_list = self.propose_tree(
|
||||
tree_root_level=0,
|
||||
batch_size=batch_size,
|
||||
logits=logits,
|
||||
positions=positions,
|
||||
@ -242,11 +234,10 @@ class EagleProposer:
|
||||
(TritonAttentionMetadata, AiterFlashAttentionMetadata,
|
||||
FlashAttentionMetadata))
|
||||
else:
|
||||
# Currently, only FlashAttention and TreeAttention support
|
||||
# multi-token eagle spec decode. This is because the code below
|
||||
# makes assumptions about attn_metadata attributes available.
|
||||
assert isinstance(attn_metadata,
|
||||
(FlashAttentionMetadata, TreeAttentionMetadata))
|
||||
# Currently, only FlashAttention supports multi-token eagle spec
|
||||
# decode. This is because the code below makes assumptions about
|
||||
# attn_metadata attributes available.
|
||||
assert isinstance(attn_metadata, FlashAttentionMetadata)
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
@ -259,7 +250,7 @@ class EagleProposer:
|
||||
attn_metadata.num_actual_tokens = batch_size
|
||||
attn_metadata.max_query_len = 1
|
||||
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||
for token_index in range(self.num_speculative_tokens - 1):
|
||||
for _ in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
@ -327,21 +318,6 @@ class EagleProposer:
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
||||
None)
|
||||
|
||||
if self.first_branching_level == token_index + 1:
|
||||
# Branching has occurred. The remaining tokens are drafted
|
||||
# using tree attention.
|
||||
draft_token_ids_list += self.propose_tree(
|
||||
tree_root_level=token_index + 1,
|
||||
batch_size=batch_size,
|
||||
logits=logits,
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
# [batch_size, num_tree_tokens]
|
||||
return torch.cat(draft_token_ids_list, dim=1)
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
@ -351,7 +327,6 @@ class EagleProposer:
|
||||
|
||||
def propose_tree(
|
||||
self,
|
||||
tree_root_level: int,
|
||||
batch_size: int,
|
||||
# [num_tokens, vocab_size]
|
||||
logits: torch.Tensor,
|
||||
@ -366,10 +341,10 @@ class EagleProposer:
|
||||
assert isinstance(tree_attn_metadata_builder,
|
||||
TreeAttentionMetadataBuilder)
|
||||
|
||||
total_num_drafts = self.cu_drafts_per_level[tree_root_level]
|
||||
total_num_drafts = self.cu_drafts_per_level[0]
|
||||
level_num_drafts = total_num_drafts
|
||||
# Sample a draft token for each child at the tree root level.
|
||||
num_children = self.child_drafts_per_level[tree_root_level]
|
||||
num_children = self.child_drafts_per_level[0]
|
||||
if num_children == 1:
|
||||
draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
|
||||
else:
|
||||
@ -393,22 +368,23 @@ class EagleProposer:
|
||||
positions.view(batch_size, -1) +
|
||||
self.tree_draft_pos_offsets[:batch_size, :])
|
||||
tree_depth = len(self.cu_drafts_per_level)
|
||||
for level in range(tree_root_level, tree_depth - 1):
|
||||
for level in range(tree_depth - 1):
|
||||
# Get draft positions for RoPE.
|
||||
draft_positions = positions + (level + 1)
|
||||
exceeds_max_model_len = (positions +
|
||||
total_num_drafts) >= self.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_draft_positions = torch.where(
|
||||
draft_positions = torch.where(
|
||||
exceeds_max_model_len,
|
||||
0,
|
||||
draft_positions,
|
||||
)
|
||||
).view(batch_size, -1)
|
||||
|
||||
if level_num_drafts > 1:
|
||||
# Repeat the positions for each draft at this level.
|
||||
draft_positions = clamped_draft_positions.repeat_interleave(
|
||||
level_num_drafts).reshape(batch_size, -1)
|
||||
draft_positions = draft_positions.repeat_interleave(
|
||||
level_num_drafts, dim=1)
|
||||
|
||||
if num_children > 1:
|
||||
# Repeat draft hidden states for each child.
|
||||
@ -425,7 +401,7 @@ class EagleProposer:
|
||||
|
||||
# Build new attention metadata for the next level of drafts.
|
||||
# This is necessary to support tree attention.
|
||||
query_len = total_num_drafts - tree_root_level
|
||||
query_len = total_num_drafts
|
||||
common_attn_metadata = replace(
|
||||
common_attn_metadata,
|
||||
query_start_loc=query_len * self.arange[:batch_size + 1],
|
||||
@ -435,7 +411,7 @@ class EagleProposer:
|
||||
)
|
||||
attn_metadata = tree_attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=tree_root_level + 1,
|
||||
draft_index=level + 1,
|
||||
)
|
||||
|
||||
# Apply new attention metadata to all layers.
|
||||
@ -516,7 +492,6 @@ class EagleProposer:
|
||||
level_num_drafts = self.cu_drafts_per_level[level +
|
||||
1] - total_num_drafts
|
||||
total_num_drafts = self.cu_drafts_per_level[level + 1]
|
||||
|
||||
return draft_token_ids_list
|
||||
|
||||
def prepare_inputs(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user