[V1][Spec Decode] Small refactors to improve eagle bookkeeping performance (#18424)

Signed-off-by: qizixi <qizixi@meta.com>
This commit is contained in:
qizixi 2025-05-23 23:51:22 -07:00 committed by GitHub
parent ec82c3e388
commit d55e446d13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 19 deletions

View File

@ -100,8 +100,12 @@ def test_prepare_inputs():
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
# a + b + c - n1 - n2 - n3
num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
).item()
cu_num_tokens, token_indices = EagleProposer.prepare_inputs( cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
cu_target_query_lens, num_rejected_tokens) cu_target_query_lens, num_rejected_tokens, num_tokens)
assert torch.equal(cu_num_tokens, expected_cu_num_tokens) assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item() assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()

View File

@ -271,6 +271,7 @@ class EagleProposer:
cu_target_query_lens: torch.Tensor, cu_target_query_lens: torch.Tensor,
# [batch_size] # [batch_size]
num_rejected_tokens: torch.Tensor, num_rejected_tokens: torch.Tensor,
num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c] # cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3] # num_rejected_tokens: [n1, n2, n3]
@ -288,18 +289,13 @@ class EagleProposer:
# [a - n1, b - n2, c - n3] -> # [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.empty_like(cu_target_query_lens) cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0
# FIXME(woosuk): Avoid synchronization.
num_tokens = cu_num_tokens[-1].item()
token_indices = torch.empty( token_indices = torch.empty(
num_tokens, num_tokens,
dtype=torch.int32, dtype=torch.int32,
device=cu_num_tokens.device, device=cu_target_query_lens.device,
) )
batch_size = num_rejected_tokens.shape[0] batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
prepare_eagle_input_kernel[(batch_size, )]( prepare_eagle_input_kernel[(batch_size, )](

View File

@ -34,8 +34,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, cdiv, check_use_alibi, GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
is_pin_memory_available) check_use_alibi, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@ -1360,9 +1360,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output.num_scheduled_tokens[req_id]) scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len) next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id) next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids, next_token_ids = async_tensor_h2d(next_token_ids,
dtype=torch.int32, dtype=torch.int32,
device=self.device) target_device=self.device,
pin_memory=True)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
# NOTE: deepseek_mtp uses MLA which does not have `block_table` # NOTE: deepseek_mtp uses MLA which does not have `block_table`
@ -1390,14 +1391,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens) for i, n in enumerate(num_draft_tokens)
] ]
num_rejected_tokens = torch.tensor( num_rejected_tokens_tensor = async_tensor_h2d(
num_rejected_tokens, num_rejected_tokens,
dtype=torch.int32, dtype=torch.int32,
device=self.device, target_device=self.device,
) pin_memory=True)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs( cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc, eagle_attn_metadata.query_start_loc,
num_rejected_tokens, num_rejected_tokens_tensor,
num_tokens,
) )
target_token_ids = self.input_ids[token_indices] target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices] target_positions = positions[token_indices]
@ -1408,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[ target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices] token_indices]
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,