diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 2b4f8bd2a8b9..7b8445a0b287 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional
 from unittest import mock
 
 import pytest
@@ -23,7 +24,11 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
 eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
 
 
-def _create_proposer(method: str, k: int) -> EagleProposer:
+def _create_proposer(
+    method: str,
+    num_speculative_tokens: int,
+    speculative_token_tree: Optional[list[tuple[int, ...]]] = None,
+) -> EagleProposer:
     model_config = ModelConfig(model=model_dir,
                                runner="generate",
                                max_model_len=100)
@@ -31,12 +36,18 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
     # Choose model directory based on method
     draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir
 
+    spec_token_tree_str = None
+    if speculative_token_tree is not None:
+        assert num_speculative_tokens == len(speculative_token_tree)
+        spec_token_tree_str = str(speculative_token_tree)
+
     speculative_config = SpeculativeConfig(
         target_model_config=model_config,
         target_parallel_config=ParallelConfig(),
         model=draft_model_dir,
         method=method,
-        num_speculative_tokens=k,
+        num_speculative_tokens=num_speculative_tokens,
+        speculative_token_tree=spec_token_tree_str,
     )
 
     vllm_config = VllmConfig(
@@ -189,7 +200,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
     target_model.lm_head = mock.MagicMock()
 
     # Create proposer using the helper function
-    proposer = _create_proposer(method, k=8)
+    proposer = _create_proposer(method, num_speculative_tokens=8)
 
     # Call the method under test
     proposer.load_model(target_model)
@@ -226,6 +237,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
         pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
                     "multi-token eagle spec decode on current platform")
 
+    if attn_backend == "TREE_ATTN":
+        pytest.skip("TREE_ATTN is tested separately in test_propose_tree "
+                    "because it requires special input mocking.")
+
     if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
@@ -378,3 +393,142 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
 
     # Verify all tokens match our expectations
     assert torch.equal(result, expected_tokens)
+
+
+@pytest.mark.parametrize(
+    "spec_token_tree",
+    [
+        [(0, )],  # A single token
+        [(0, ), (0, 0), (0, 0, 0)],  # Chain
+        [(0, ), (1, ), (2, )],  # Parallel
+        [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
+         (2, 1)],  # Tree
+    ])
+def test_propose_tree(spec_token_tree):
+    # Get GPU device.
+    device = torch.device(current_platform.device_type)
+
+    # Setup test parameters.
+    batch_size = 2
+    seq_len_1 = 5
+    seq_len_2 = 3
+    total_tokens = seq_len_1 + seq_len_2
+    vocab_size = 100
+    seq_lens = [seq_len_1, seq_len_2]
+    num_speculative_tokens = len(spec_token_tree)
+
+    # Create proposer first so we can use its actual hidden_size.
+    proposer = _create_proposer("eagle",
+                                num_speculative_tokens,
+                                speculative_token_tree=spec_token_tree)
+    # Get the hidden_size from the proposer to ensure consistency.
+    hidden_size = proposer.hidden_size
+
+    # Helper to create deterministic logits that will produce specific tokens.
+    def create_deterministic_logits(token_ids, k: int):
+        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
+        for i, token_id in enumerate(token_ids):
+            # Assign decreasing values to the k consecutive tokens.
+            for j in range(k):
+                logits[i, token_id + j] = 100.0 - j
+        return logits
+
+    # Base token IDs used to seed the deterministic logits (one per request).
+    base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)
+
+    # Skip loading the model and replace it with a mock that returns
+    # deterministic outputs.
+    model_mock = mock.MagicMock()
+
+    # Mock the model forward calls.
+    forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device),
+                        torch.zeros(total_tokens, hidden_size, device=device))]
+    for cu_num_drafts in proposer.cu_drafts_per_level:
+        h_logits = torch.zeros(batch_size * cu_num_drafts,
+                               hidden_size,
+                               device=device)
+        h_states = torch.zeros(batch_size * cu_num_drafts,
+                               hidden_size,
+                               device=device)
+        forward_returns.append((h_logits, h_states))
+    model_mock.side_effect = forward_returns
+
+    # Mock the compute_logits calls.
+    cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level,
+                                        dtype=torch.int32,
+                                        device=device)
+    logits_returns = []
+    for level, num_children in enumerate(proposer.child_drafts_per_level):
+        token_ids = base_token_ids + cu_num_drafts_tensor[level]
+        level_num_drafts = cu_num_drafts_tensor[
+            level + 1] - cu_num_drafts_tensor[level]
+        level_logits = []
+        for i in range(level_num_drafts // num_children):
+            level_logits.append(
+                create_deterministic_logits(token_ids + i * num_children,
+                                            num_children))
+        logits_returns.append(torch.stack(level_logits, dim=1))
+    model_mock.compute_logits.side_effect = logits_returns
+
+    # Assign the mock to the proposer
+    proposer.model = model_mock
+
+    # Assign draft attn_layer_names since load_model is not invoked
+    proposer.attn_layer_names = ["layer.0"]
+
+    # Get the tree attention metadata builder.
+    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )
+
+    # Mock runner for attention metadata building.
+    proposer.runner = mock.MagicMock()
+    proposer.runner.attn_groups.append([mock.MagicMock()])
+    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
+
+    # Setup inputs for the proposer.
+    target_token_ids = torch.randint(0,
+                                     vocab_size, (total_tokens, ),
+                                     device=device)
+    target_positions = torch.cat([
+        torch.arange(seq_len_1, device=device),
+        torch.arange(seq_len_2, device=device)
+    ])
+    target_hidden_states = torch.randn(total_tokens,
+                                       hidden_size,
+                                       device=device)
+    next_token_ids = torch.randint(0,
+                                   vocab_size, (batch_size, ),
+                                   dtype=torch.int32,
+                                   device=device)
+    batch_spec = BatchSpec(
+        seq_lens=seq_lens,
+        query_lens=seq_lens,
+    )
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
+    sampling_metadata = mock.MagicMock()
+
+    # Propose draft tokens.
+    result = proposer.propose(target_token_ids=target_token_ids,
+                              target_positions=target_positions,
+                              target_hidden_states=target_hidden_states,
+                              next_token_ids=next_token_ids,
+                              common_attn_metadata=common_attn_metadata,
+                              sampling_metadata=sampling_metadata)
+    assert result.shape == (batch_size, num_speculative_tokens)
+
+    # The tokens are expected to be consecutive integers starting
+    # from the base token IDs.
+    expected_tokens = base_token_ids[:, None] + torch.arange(
+        num_speculative_tokens, dtype=torch.int64, device=device)
+
+    # Verify that the draft tokens match our expectations.
+    assert torch.equal(result, expected_tokens)
diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py
index 01019b29e010..a5b10bb51866 100644
--- a/tests/v1/spec_decode/test_max_len.py
+++ b/tests/v1/spec_decode/test_max_len.py
@@ -39,12 +39,6 @@ def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
                        num_speculative_tokens: int, attn_backend: str):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-
-        if attn_backend == "TREE_ATTN" and num_speculative_tokens > 1:
-            # TREE_ATTN fails the test with multi-token spec decode
-            # TODO: Investigate why
-            pytest.skip("TREE_ATTN fails the test")
-
         m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
 
         if (attn_backend == "TRITON_ATTN_VLLM_V1"
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index 3b53b039f1dc..5d10e9e26082 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -236,9 +236,9 @@ class TreeAttentionMetadataBuilder(
             # Use prefill for drafting at the root level.
             self.tree_attn_bias = torch.empty(0)
         else:
-            # Slice the tree attention bias for drafting.
-            query_len = common_attn_metadata.max_query_len
-            start, end = draft_index, draft_index + query_len
+            # Slice the tree attention bias for drafting. Exclude
+            # the root level.
+            start, end = 1, 1 + common_attn_metadata.max_query_len
             self.tree_attn_bias = self.tree_attn_bias[start:end,
                                                       start:end].contiguous()
 
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index f75d76dd978f..a8a160a0f995 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -113,13 +113,6 @@ class EagleProposer:
                                             num_drafts_per_level[level])
             self.child_drafts_per_level.append(num_drafts_per_level[level] //
                                                num_drafts_per_level[level - 1])
-        # Find the first level where the tree branches off into one or more
-        # children.
-        self.first_branching_level = None
-        for level in range(tree_depth):
-            if self.cu_drafts_per_level[level] > level + 1:
-                self.first_branching_level = level
-                break
         # Precompute draft position offsets in flattened tree.
         self.tree_draft_pos_offsets = torch.arange(
             1,
@@ -209,11 +202,10 @@ class EagleProposer:
             logits = self.model.compute_logits(sample_hidden_states, None)
         positions = target_positions[last_token_indices]
         hidden_states = hidden_states[last_token_indices]
-        if self.first_branching_level == 0:
-            # Branching has occurred at the root level. Draft using tree
-            # attention.
+
+        if isinstance(attn_metadata, TreeAttentionMetadata):
+            # Draft using tree attention.
             draft_token_ids_list = self.propose_tree(
-                tree_root_level=0,
                 batch_size=batch_size,
                 logits=logits,
                 positions=positions,
@@ -242,11 +234,10 @@ class EagleProposer:
                 (TritonAttentionMetadata, AiterFlashAttentionMetadata,
                  FlashAttentionMetadata))
         else:
-            # Currently, only FlashAttention and TreeAttention support
-            # multi-token eagle spec decode. This is because the code below
-            # makes assumptions about attn_metadata attributes available.
-            assert isinstance(attn_metadata,
-                              (FlashAttentionMetadata, TreeAttentionMetadata))
+            # Currently, only FlashAttention supports multi-token eagle spec
+            # decode. This is because the code below makes assumptions about
+            # attn_metadata attributes available.
+            assert isinstance(attn_metadata, FlashAttentionMetadata)
 
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
@@ -259,7 +250,7 @@ class EagleProposer:
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-        for token_index in range(self.num_speculative_tokens - 1):
+        for _ in range(self.num_speculative_tokens - 1):
             # Update the inputs.
             # cast to int32 is crucial when eagle model is compiled.
             # tensor.argmax() returns int64 by default.
@@ -327,21 +318,6 @@ class EagleProposer:
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                                None)
-
-            if self.first_branching_level == token_index + 1:
-                # Branching has occurred. The remaining tokens are drafted
-                # using tree attention.
-                draft_token_ids_list += self.propose_tree(
-                    tree_root_level=token_index + 1,
-                    batch_size=batch_size,
-                    logits=logits,
-                    positions=positions,
-                    hidden_states=hidden_states,
-                    common_attn_metadata=common_attn_metadata,
-                )
-                # [batch_size, num_tree_tokens]
-                return torch.cat(draft_token_ids_list, dim=1)
-
             draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)
 
@@ -351,7 +327,6 @@ class EagleProposer:
 
     def propose_tree(
         self,
-        tree_root_level: int,
         batch_size: int,
         # [num_tokens, vocab_size]
         logits: torch.Tensor,
@@ -366,10 +341,10 @@ class EagleProposer:
         assert isinstance(tree_attn_metadata_builder,
                           TreeAttentionMetadataBuilder)
 
-        total_num_drafts = self.cu_drafts_per_level[tree_root_level]
+        total_num_drafts = self.cu_drafts_per_level[0]
         level_num_drafts = total_num_drafts
         # Sample a draft token for each child at the tree root level.
-        num_children = self.child_drafts_per_level[tree_root_level]
+        num_children = self.child_drafts_per_level[0]
         if num_children == 1:
             draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
         else:
@@ -393,22 +368,23 @@ class EagleProposer:
             positions.view(batch_size, -1) +
             self.tree_draft_pos_offsets[:batch_size, :])
         tree_depth = len(self.cu_drafts_per_level)
-        for level in range(tree_root_level, tree_depth - 1):
+        for level in range(tree_depth - 1):
             # Get draft positions for RoPE.
             draft_positions = positions + (level + 1)
             exceeds_max_model_len = (positions +
                                      total_num_drafts) >= self.max_model_len
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
-            clamped_draft_positions = torch.where(
+            draft_positions = torch.where(
                 exceeds_max_model_len,
                 0,
                 draft_positions,
-            )
+            ).view(batch_size, -1)
+
             if level_num_drafts > 1:
                 # Repeat the positions for each draft at this level.
-                draft_positions = clamped_draft_positions.repeat_interleave(
-                    level_num_drafts).reshape(batch_size, -1)
+                draft_positions = draft_positions.repeat_interleave(
+                    level_num_drafts, dim=1)
 
             if num_children > 1:
                 # Repeat draft hidden states for each child.
@@ -425,7 +401,7 @@ class EagleProposer:
 
             # Build new attention metadata for the next level of drafts.
             # This is necessary to support tree attention.
-            query_len = total_num_drafts - tree_root_level
+            query_len = total_num_drafts
             common_attn_metadata = replace(
                 common_attn_metadata,
                 query_start_loc=query_len * self.arange[:batch_size + 1],
@@ -435,7 +411,7 @@ class EagleProposer:
             )
             attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                 common_attn_metadata=common_attn_metadata,
-                draft_index=tree_root_level + 1,
+                draft_index=level + 1,
             )
 
             # Apply new attention metadata to all layers.
@@ -516,7 +492,6 @@ class EagleProposer:
             level_num_drafts = self.cu_drafts_per_level[level +
                                                         1] - total_num_drafts
             total_num_drafts = self.cu_drafts_per_level[level + 1]
-
         return draft_token_ids_list
 
     def prepare_inputs(
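
Reviewer note (illustrative, not part of the patch): the sketch below shows how a spec_token_tree value from the test_propose_tree parametrization maps to the cumulative and per-parent draft counts (cu_drafts_per_level, child_drafts_per_level) that the test reads from EagleProposer when sizing its mocked forward and compute_logits outputs. It assumes uniform branching at each level, which holds for every tree in the parametrization; the helper name tree_level_counts is hypothetical and is not defined anywhere in vLLM.

# Illustrative sketch only (not from this patch); assumes uniform branching.
from collections import Counter


def tree_level_counts(spec_token_tree: list[tuple[int, ...]]):
    # Count draft nodes per tree level; a path of length L sits at level L.
    per_level = Counter(len(path) for path in spec_token_tree)
    depth = max(per_level)
    num_drafts_per_level = [per_level[level] for level in range(1, depth + 1)]

    # Cumulative drafts up to and including each level
    # (mirrors the cu_drafts_per_level values consumed by the test above).
    cu_drafts_per_level = []
    running_total = 0
    for count in num_drafts_per_level:
        running_total += count
        cu_drafts_per_level.append(running_total)

    # Children sampled per parent at each level, assuming uniform branching
    # (mirrors the child_drafts_per_level values consumed by the test above).
    child_drafts_per_level = [num_drafts_per_level[0]]
    for level in range(1, depth):
        child_drafts_per_level.append(num_drafts_per_level[level] //
                                      num_drafts_per_level[level - 1])
    return cu_drafts_per_level, child_drafts_per_level


# The "Tree" case from the parametrization: 3 root drafts, 2 children each.
tree = [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
assert tree_level_counts(tree) == ([3, 9], [3, 2])
# The "Chain" case: one draft per level, so no branching anywhere.
assert tree_level_counts([(0, ), (0, 0), (0, 0, 0)]) == ([1, 2, 3], [1, 1, 1])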