[V1][Spec Decode] Small refactors to improve eagle bookkeeping performance (#18424)

Signed-off-by: qizixi <qizixi@meta.com>
Author: qizixi
Date: 2025-05-23 23:51:22 -07:00
Committed by: GitHub
parent ec82c3e388
commit d55e446d13
3 changed files with 21 additions and 19 deletions
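
The refactor has two parts, both visible in the hunks below: EagleProposer.prepare_inputs now receives num_tokens as an argument instead of reading it back from the GPU with .item() (which forces a device sync), and host-to-device copies go through async_tensor_h2d with pinned memory instead of blocking torch.tensor(..., device=...) calls. A minimal sketch of the two counting patterns, with illustrative names and values that are not taken from the commit:

import torch

# Old pattern: .item() on a GPU tensor blocks the host until the device
# catches up, serializing CPU bookkeeping behind GPU work.
def num_tokens_with_sync(cu_num_tokens: torch.Tensor) -> int:
    return int(cu_num_tokens[-1].item())

# New pattern: derive the same count from host-side values that are
# already available, so no GPU readback is needed.
def num_tokens_without_sync(num_scheduled_tokens: int,
                            num_rejected_tokens: list[int]) -> int:
    return num_scheduled_tokens - sum(num_rejected_tokens)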

tests/v1/spec_decode/test_eagle.py

@@ -100,8 +100,12 @@ def test_prepare_inputs():
                                        dtype=torch.int32,
                                        device=device)
+    # num_tokens = a + b + c - n1 - n2 - n3
+    num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
+    ).item()
     cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
-        cu_target_query_lens, num_rejected_tokens)
+        cu_target_query_lens, num_rejected_tokens, num_tokens)
     assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
     assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
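
For concreteness, this is the arithmetic the updated test performs, with illustrative per-request lengths a=4, b=7, c=5 and rejection counts n1=1, n2=3, n3=2 (these specific numbers are assumed for the example, not taken from the test):

import torch

cu_target_query_lens = torch.tensor([0, 4, 11, 16], dtype=torch.int32)
num_rejected_tokens = torch.tensor([1, 3, 2], dtype=torch.int32)

# a + b + c - n1 - n2 - n3 = 16 - 6 = 10
num_tokens = (cu_target_query_lens[-1].item() -
              num_rejected_tokens.sum().item())
assert num_tokens == 10

The .item() sync is acceptable here because it happens in test code; the point of the refactor is to keep it out of the runner's hot path.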

vllm/v1/spec_decode/eagle.py

@@ -271,6 +271,7 @@ class EagleProposer:
         cu_target_query_lens: torch.Tensor,
         # [batch_size]
         num_rejected_tokens: torch.Tensor,
+        num_tokens: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # cu_target_query_lens: [0, a, a + b, a + b + c]
         # num_rejected_tokens: [n1, n2, n3]
@@ -288,18 +289,13 @@ class EagleProposer:
         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-        cu_num_tokens = torch.empty_like(cu_target_query_lens)
+        cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-        cu_num_tokens[0] = 0
-        # FIXME(woosuk): Avoid synchronization.
-        num_tokens = cu_num_tokens[-1].item()
         token_indices = torch.empty(
             num_tokens,
             dtype=torch.int32,
-            device=cu_num_tokens.device,
+            device=cu_target_query_lens.device,
         )
         batch_size = num_rejected_tokens.shape[0]
         BLOCK_SIZE = 1024
         prepare_eagle_input_kernel[(batch_size, )](
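
Under my reading of the comments above, prepare_inputs keeps, for each request, the first (length - rejected) target tokens and builds cumulative offsets over the kept tokens; the index scatter itself is done by prepare_eagle_input_kernel, a Triton kernel. A pure-Python reference sketch of that bookkeeping (it mirrors the intended output, not the kernel):

def prepare_inputs_reference(cu_target_query_lens: list[int],
                             num_rejected_tokens: list[int]):
    # cu_target_query_lens: [0, a, a + b, a + b + c]
    # num_rejected_tokens:  [n1, n2, n3]
    cu_num_tokens = [0]
    token_indices = []
    for i, n_rejected in enumerate(num_rejected_tokens):
        start = cu_target_query_lens[i]
        keep = cu_target_query_lens[i + 1] - start - n_rejected
        token_indices.extend(range(start, start + keep))
        cu_num_tokens.append(cu_num_tokens[-1] + keep)
    return cu_num_tokens, token_indices

# With lengths [4, 7, 5] and rejections [1, 3, 2]:
#   cu_num_tokens == [0, 3, 7, 10]
#   token_indices == [0, 1, 2, 4, 5, 6, 7, 11, 12, 13]

Note also the zeros_like change: initializing the whole buffer to zero makes the leading 0 implicit, removing the separate cu_num_tokens[0] = 0 write, and passing num_tokens in removes the .item() readback flagged by the old FIXME.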

vllm/v1/worker/gpu_model_runner.py

@@ -34,8 +34,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        is_pin_memory_available)
+                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
+                        check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -281,7 +281,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         """
         Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may 
+        backend's needs. For example, some attention backends (namely MLA) may
         want to separate requests based on if the attention computation will be
         compute-bound or memory-bound.
@@ -1360,9 +1360,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                scheduler_output.num_scheduled_tokens[req_id])
                     next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
+            next_token_ids = async_tensor_h2d(next_token_ids,
+                                              dtype=torch.int32,
+                                              target_device=self.device,
+                                              pin_memory=True)
             eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]

             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
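
async_tensor_h2d comes from vllm.utils; to the best of my recollection it stages the list in (optionally pinned) host memory and then issues a non-blocking copy, roughly like the sketch below (an approximation, not the verbatim implementation):

import torch

def async_tensor_h2d_sketch(data, dtype, target_device, pin_memory):
    # Stage in pinned (page-locked) host memory so the DMA engine can
    # copy the buffer without the OS paging it out...
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    # ...then enqueue a non-blocking copy: the host continues immediately
    # instead of waiting for the transfer to complete.
    return t.to(device=target_device, non_blocking=True)

A non-blocking host-to-device copy only truly overlaps with CPU work when the source buffer is pinned, which is why pin_memory=True accompanies each call site in this commit.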
@@ -1390,14 +1391,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                 for i, n in enumerate(num_draft_tokens)
             ]
-            num_rejected_tokens = torch.tensor(
-                num_rejected_tokens,
-                dtype=torch.int32,
-                device=self.device,
-            )
+            num_rejected_tokens_tensor = async_tensor_h2d(
+                num_rejected_tokens,
+                dtype=torch.int32,
+                target_device=self.device,
+                pin_memory=True)
+            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
             cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                 eagle_attn_metadata.query_start_loc,
-                num_rejected_tokens,
+                num_rejected_tokens_tensor,
+                num_tokens,
             )
             target_token_ids = self.input_ids[token_indices]
             target_positions = positions[token_indices]
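
The rename to num_rejected_tokens_tensor keeps the original Python list alive, so num_tokens is computed with plain list arithmetic before (and independently of) the device copy. A sketch with assumed values:

num_scheduled_tokens = 16        # assumed value for illustration
num_rejected_tokens = [1, 3, 2]  # assumed values for illustration

# Host-side subtraction: no GPU tensor is read back, so no sync point.
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)  # == 10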
@@ -1408,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             target_hidden_states = hidden_states[token_indices]
             target_slot_mapping = eagle_attn_metadata.slot_mapping[
                 token_indices]
-
         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
             target_positions=target_positions,