[V1] Add tree drafting tests for eagle spec decoding (#22705)

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>

parent 3f52738dce
commit d94e3026de
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional
 from unittest import mock
 
 import pytest
@@ -23,7 +24,11 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
 eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
 
 
-def _create_proposer(method: str, k: int) -> EagleProposer:
+def _create_proposer(
+    method: str,
+    num_speculative_tokens: int,
+    speculative_token_tree: Optional[list[tuple[int]]] = None,
+) -> EagleProposer:
     model_config = ModelConfig(model=model_dir,
                                runner="generate",
                                max_model_len=100)
@@ -31,12 +36,18 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
     # Choose model directory based on method
     draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir
 
+    spec_token_tree_str = None
+    if speculative_token_tree is not None:
+        assert num_speculative_tokens == len(speculative_token_tree)
+        spec_token_tree_str = str(speculative_token_tree)
+
     speculative_config = SpeculativeConfig(
         target_model_config=model_config,
         target_parallel_config=ParallelConfig(),
         model=draft_model_dir,
         method=method,
-        num_speculative_tokens=k,
+        num_speculative_tokens=num_speculative_tokens,
+        speculative_token_tree=spec_token_tree_str,
     )
 
     vllm_config = VllmConfig(
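For reference, a minimal usage sketch of the extended helper: the tree is a flat list of child-index paths from the root, and the helper serializes it with str() into SpeculativeConfig. The specific tree below is illustrative only, not one of the trees used by the tests.

# Illustrative tree: two children at the root, one grandchild under each.
tree = [(0, ), (1, ), (0, 0), (1, 0)]
proposer = _create_proposer("eagle",
                            num_speculative_tokens=len(tree),
                            speculative_token_tree=tree)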
@@ -189,7 +200,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
     target_model.lm_head = mock.MagicMock()
 
     # Create proposer using the helper function
-    proposer = _create_proposer(method, k=8)
+    proposer = _create_proposer(method, num_speculative_tokens=8)
 
     # Call the method under test
     proposer.load_model(target_model)
@@ -226,6 +237,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
         pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
                     "multi-token eagle spec decode on current platform")
 
+    if (attn_backend == "TREE_ATTN"):
+        pytest.skip("TREE_ATTN is tested separately in test_propose_tree "
+                    "because it requires special input mocking.")
+
     if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
@@ -378,3 +393,142 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
 
     # Verify all tokens match our expectations
     assert torch.equal(result, expected_tokens)
+
+
+@pytest.mark.parametrize(
+    "spec_token_tree",
+    [
+        [(0, )],  # A single token
+        [(0, ), (0, 0), (0, 0, 0)],  # Chain
+        [(0, ), (1, ), (2, )],  # Parallel
+        [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
+         (2, 1)],  # Tree
+    ])
+def test_propose_tree(spec_token_tree):
+    # Get GPU device.
+    device = torch.device(current_platform.device_type)
+
+    # Setup test parameters.
+    batch_size = 2
+    seq_len_1 = 5
+    seq_len_2 = 3
+    total_tokens = seq_len_1 + seq_len_2
+    vocab_size = 100
+    seq_lens = [seq_len_1, seq_len_2]
+    num_speculative_tokens = len(spec_token_tree)
+
+    # Create proposer first so we can use its actual hidden_size.
+    proposer = _create_proposer("eagle",
+                                num_speculative_tokens,
+                                speculative_token_tree=spec_token_tree)
+    # Get the hidden_size from the proposer to ensure consistency.
+    hidden_size = proposer.hidden_size
+
+    # Helper to create deterministic logits that will produce specific tokens
+    def create_deterministic_logits(token_ids, k: int):
+        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
+        for i, token_id in enumerate(token_ids):
+            # Assign decreasing values to the k consecutive tokens.
+            for j in range(k):
+                logits[i, token_id + j] = 100.0 - j
+        return logits
+
+    # Mock a model that returns deterministic logits.
+    base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)
+
+    # Skip loading the model and replace it with a mock that returns
+    # deterministic outputs.
+    model_mock = mock.MagicMock()
+
+    # Mock the model forward calls.
+    forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device),
+                        torch.zeros(total_tokens, hidden_size, device=device))]
+    for cu_num_drafts in proposer.cu_drafts_per_level:
+        h_logits = torch.zeros(batch_size * cu_num_drafts,
+                               hidden_size,
+                               device=device)
+        h_states = torch.zeros(batch_size * cu_num_drafts,
+                               hidden_size,
+                               device=device)
+        forward_returns.append((h_logits, h_states))
+    model_mock.side_effect = forward_returns
+
+    # Mock the compute_logits calls.
+    cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level,
+                                        dtype=torch.int32,
+                                        device=device)
+    logits_returns = []
+    for level, num_children in enumerate(proposer.child_drafts_per_level):
+        token_ids = base_token_ids + cu_num_drafts_tensor[level]
+        level_num_drafts = cu_num_drafts_tensor[
+            level + 1] - cu_num_drafts_tensor[level]
+        level_logits = []
+        for i in range(level_num_drafts // num_children):
+            level_logits.append(
+                create_deterministic_logits(token_ids + i * num_children,
+                                            num_children))
+        logits_returns.append(torch.stack(level_logits, dim=1))
+    model_mock.compute_logits.side_effect = logits_returns
+
+    # Assign the mock to the proposer
+    proposer.model = model_mock
+
+    # Assign draft attn_layer_names since load_model is not invoked
+    proposer.attn_layer_names = ["layer.0"]
+
+    # Get the tree attention metadata builder.
+    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )
+
+    # Mock runner for attention metadata building.
+    proposer.runner = mock.MagicMock()
+    proposer.runner.attn_groups.append([mock.MagicMock()])
+    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
+
+    # Setup inputs for the proposer.
+    target_token_ids = torch.randint(0,
+                                     vocab_size, (total_tokens, ),
+                                     device=device)
+    target_positions = torch.cat([
+        torch.arange(seq_len_1, device=device),
+        torch.arange(seq_len_2, device=device)
+    ])
+    target_hidden_states = torch.randn(total_tokens,
+                                       hidden_size,
+                                       device=device)
+    next_token_ids = torch.randint(0,
+                                   vocab_size, (batch_size, ),
+                                   dtype=torch.int32,
+                                   device=device)
+    batch_spec = BatchSpec(
+        seq_lens=seq_lens,
+        query_lens=seq_lens,
+    )
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
+    sampling_metadata = mock.MagicMock()
+
+    # Propose draft tokens.
+    result = proposer.propose(target_token_ids=target_token_ids,
+                              target_positions=target_positions,
+                              target_hidden_states=target_hidden_states,
+                              next_token_ids=next_token_ids,
+                              common_attn_metadata=common_attn_metadata,
+                              sampling_metadata=sampling_metadata)
+    assert result.shape == (batch_size, num_speculative_tokens)
+
+    # The tokens are expected to be consecutive integers starting
+    # from the base token IDs.
+    expected_tokens = base_token_ids[:, None] + torch.arange(
+        num_speculative_tokens, dtype=torch.int64, device=device)
+
+    # Verify that the draft tokens match our expectations.
+    assert torch.equal(result, expected_tokens)
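A small self-contained illustration of what create_deterministic_logits in the new test produces: the k highest-scoring entries of each row are the k consecutive token ids starting at the given base id, in decreasing score order, which is why the test can predict the drafted tokens exactly. The constants mirror the test, but this snippet is only a sketch.

import torch

vocab_size, k = 100, 3
token_id = 42
logits = torch.full((1, vocab_size), -100.0)
for j in range(k):
    logits[0, token_id + j] = 100.0 - j
# The k highest-scoring tokens are 42, 43, 44, in that order.
assert torch.topk(logits, k, dim=-1).indices.tolist() == [[42, 43, 44]]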
@@ -39,12 +39,6 @@ def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
                        num_speculative_tokens: int, attn_backend: str):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
 
-        if attn_backend == "TREE_ATTN" and num_speculative_tokens > 1:
-            # TREE_ATTN fails the test with multi-token spec decode
-            # TODO: Investigate why
-            pytest.skip("TREE_ATTN fails the test")
-
         m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
 
         if (attn_backend == "TRITON_ATTN_VLLM_V1"
@@ -236,9 +236,9 @@ class TreeAttentionMetadataBuilder(
             # Use prefill for drafting at the root level.
             self.tree_attn_bias = torch.empty(0)
         else:
-            # Slice the tree attention bias for drafting.
-            query_len = common_attn_metadata.max_query_len
-            start, end = draft_index, draft_index + query_len
+            # Slice the tree attention bias for drafting. Exclude
+            # the root level.
+            start, end = 1, 1 + common_attn_metadata.max_query_len
             self.tree_attn_bias = self.tree_attn_bias[start:end,
                                                       start:end].contiguous()
 
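A minimal sketch of the slicing above, assuming the bias is laid out with the root token at index 0 followed by the draft positions (that layout is an assumption for this example, consistent with the "Exclude the root level" comment): taking rows and columns [1 : 1 + max_query_len] keeps only the draft-vs-draft visibility mask.

import torch

neg_inf = float("-inf")
# Hypothetical 3x3 bias: root at index 0, two draft positions after it.
full_bias = torch.tensor([
    [0.0, neg_inf, neg_inf],
    [0.0, 0.0, neg_inf],
    [0.0, neg_inf, 0.0],
])
max_query_len = 2
start, end = 1, 1 + max_query_len
draft_bias = full_bias[start:end, start:end].contiguous()
# draft_bias is the 2x2 block covering only the draft tokens.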
@@ -113,13 +113,6 @@ class EagleProposer:
                 num_drafts_per_level[level])
             self.child_drafts_per_level.append(num_drafts_per_level[level] //
                                                num_drafts_per_level[level - 1])
-        # Find the first level where the tree branches off into one or more
-        # children.
-        self.first_branching_level = None
-        for level in range(tree_depth):
-            if self.cu_drafts_per_level[level] > level + 1:
-                self.first_branching_level = level
-                break
         # Precompute draft position offsets in flattened tree.
         self.tree_draft_pos_offsets = torch.arange(
             1,
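For intuition, a short sketch of how a flattened token tree maps to per-level draft counts. Using the "Tree" case from the test parametrization, the cumulative counts come out to [3, 9] and each level-1 draft has two children. The attribute name cu_drafts_per_level is the one used above, but this computation is a simplified stand-in, not the class's own code.

spec_token_tree = [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1),
                   (2, 0), (2, 1)]

# Count drafts at each depth of the tree (depth = length of the child path).
drafts_per_level: dict[int, int] = {}
for path in spec_token_tree:
    drafts_per_level[len(path)] = drafts_per_level.get(len(path), 0) + 1

cumulative = []
total = 0
for level in sorted(drafts_per_level):
    total += drafts_per_level[level]
    cumulative.append(total)

assert drafts_per_level == {1: 3, 2: 6}
assert cumulative == [3, 9]  # matches cu_drafts_per_level for this tree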
@@ -209,11 +202,10 @@ class EagleProposer:
         logits = self.model.compute_logits(sample_hidden_states, None)
         positions = target_positions[last_token_indices]
         hidden_states = hidden_states[last_token_indices]
-        if self.first_branching_level == 0:
-            # Branching has occurred at the root level. Draft using tree
-            # attention.
+        if isinstance(attn_metadata, TreeAttentionMetadata):
+            # Draft using tree attention.
             draft_token_ids_list = self.propose_tree(
-                tree_root_level=0,
                 batch_size=batch_size,
                 logits=logits,
                 positions=positions,
@@ -242,11 +234,10 @@ class EagleProposer:
                 (TritonAttentionMetadata, AiterFlashAttentionMetadata,
                  FlashAttentionMetadata))
         else:
-            # Currently, only FlashAttention and TreeAttention support
-            # multi-token eagle spec decode. This is because the code below
-            # makes assumptions about attn_metadata attributes available.
-            assert isinstance(attn_metadata,
-                              (FlashAttentionMetadata, TreeAttentionMetadata))
+            # Currently, only FlashAttention supports multi-token eagle spec
+            # decode. This is because the code below makes assumptions about
+            # attn_metadata attributes available.
+            assert isinstance(attn_metadata, FlashAttentionMetadata)
 
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
@@ -259,7 +250,7 @@ class EagleProposer:
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-        for token_index in range(self.num_speculative_tokens - 1):
+        for _ in range(self.num_speculative_tokens - 1):
             # Update the inputs.
             # cast to int32 is crucial when eagle model is compiled.
             # tensor.argmax() returns int64 by default.
@@ -327,21 +318,6 @@ class EagleProposer:
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                                None)
 
-            if self.first_branching_level == token_index + 1:
-                # Branching has occurred. The remaining tokens are drafted
-                # using tree attention.
-                draft_token_ids_list += self.propose_tree(
-                    tree_root_level=token_index + 1,
-                    batch_size=batch_size,
-                    logits=logits,
-                    positions=positions,
-                    hidden_states=hidden_states,
-                    common_attn_metadata=common_attn_metadata,
-                )
-                # [batch_size, num_tree_tokens]
-                return torch.cat(draft_token_ids_list, dim=1)
-
             draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)
@@ -351,7 +327,6 @@ class EagleProposer:
 
     def propose_tree(
         self,
-        tree_root_level: int,
         batch_size: int,
         # [num_tokens, vocab_size]
         logits: torch.Tensor,
@@ -366,10 +341,10 @@ class EagleProposer:
         assert isinstance(tree_attn_metadata_builder,
                           TreeAttentionMetadataBuilder)
 
-        total_num_drafts = self.cu_drafts_per_level[tree_root_level]
+        total_num_drafts = self.cu_drafts_per_level[0]
         level_num_drafts = total_num_drafts
         # Sample a draft token for each child at the tree root level.
-        num_children = self.child_drafts_per_level[tree_root_level]
+        num_children = self.child_drafts_per_level[0]
         if num_children == 1:
             draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
         else:
@@ -393,22 +368,23 @@ class EagleProposer:
             positions.view(batch_size, -1) +
             self.tree_draft_pos_offsets[:batch_size, :])
         tree_depth = len(self.cu_drafts_per_level)
-        for level in range(tree_root_level, tree_depth - 1):
+        for level in range(tree_depth - 1):
             # Get draft positions for RoPE.
             draft_positions = positions + (level + 1)
             exceeds_max_model_len = (positions +
                                      total_num_drafts) >= self.max_model_len
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
-            clamped_draft_positions = torch.where(
+            draft_positions = torch.where(
                 exceeds_max_model_len,
                 0,
                 draft_positions,
-            )
+            ).view(batch_size, -1)
 
             if level_num_drafts > 1:
                 # Repeat the positions for each draft at this level.
-                draft_positions = clamped_draft_positions.repeat_interleave(
-                    level_num_drafts).reshape(batch_size, -1)
+                draft_positions = draft_positions.repeat_interleave(
+                    level_num_drafts, dim=1)
 
             if num_children > 1:
                 # Repeat draft hidden states for each child.
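A tiny sketch of the repeat_interleave change above: with dim=1, the clamped per-sequence positions (shape [batch_size, 1] at the root level) are repeated once per draft at the current level, giving a [batch_size, level_num_drafts] layout. The values below are made up for illustration.

import torch

draft_positions = torch.tensor([[6], [4]])  # one position per sequence
level_num_drafts = 3
per_draft = draft_positions.repeat_interleave(level_num_drafts, dim=1)
assert per_draft.tolist() == [[6, 6, 6], [4, 4, 4]]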
@@ -425,7 +401,7 @@ class EagleProposer:
 
             # Build new attention metadata for the next level of drafts.
             # This is necessary to support tree attention.
-            query_len = total_num_drafts - tree_root_level
+            query_len = total_num_drafts
             common_attn_metadata = replace(
                 common_attn_metadata,
                 query_start_loc=query_len * self.arange[:batch_size + 1],
@@ -435,7 +411,7 @@ class EagleProposer:
             )
             attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                 common_attn_metadata=common_attn_metadata,
-                draft_index=tree_root_level + 1,
+                draft_index=level + 1,
             )
 
             # Apply new attention metadata to all layers.
@@ -516,7 +492,6 @@ class EagleProposer:
             level_num_drafts = self.cu_drafts_per_level[level +
                                                         1] - total_num_drafts
             total_num_drafts = self.cu_drafts_per_level[level + 1]
 
         return draft_token_ids_list
 
     def prepare_inputs(