[V1][Spec Decode] Small refactors to improve eagle bookkeeping performance (#18424)
Signed-off-by: qizixi <qizixi@meta.com>
commit d55e446d13 (parent ec82c3e388)
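
The refactor has two threads, both aimed at keeping EAGLE bookkeeping off the synchronization path: EagleProposer.prepare_inputs now receives the total token count as a plain num_tokens: int from the caller instead of reading cu_num_tokens[-1].item() back from the GPU, and the host-side lists built by the GPU model runner (next_token_ids, num_rejected_tokens) are moved to the device with async_tensor_h2d (pinned staging plus non-blocking copy) instead of a blocking torch.tensor(..., device=...). A minimal sketch of that copy pattern in plain PyTorch, with toy data; it illustrates what the helper is used for here, not vLLM's implementation:

import torch

values = [7, 3, 9]               # toy host-side list (e.g. one token id per request)
device = torch.device("cuda:0")  # assumes a CUDA device is available

# Stage the data in pinned (page-locked) host memory, then queue a
# non-blocking copy; the host thread does not wait for the transfer.
staging = torch.tensor(values, dtype=torch.int32, pin_memory=True)
on_device = staging.to(device, non_blocking=True)

# Versus the previous pattern, which copies from pageable memory and blocks:
blocking = torch.tensor(values, dtype=torch.int32, device=device)
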
@@ -100,8 +100,12 @@ def test_prepare_inputs():
                                        dtype=torch.int32,
                                        device=device)
 
+    # a + b + c - n1 - n2 - n3
+    num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
+    ).item()
+
     cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
-        cu_target_query_lens, num_rejected_tokens)
+        cu_target_query_lens, num_rejected_tokens, num_tokens)
 
     assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
     assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
@@ -271,6 +271,7 @@ class EagleProposer:
         cu_target_query_lens: torch.Tensor,
         # [batch_size]
         num_rejected_tokens: torch.Tensor,
+        num_tokens: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # cu_target_query_lens: [0, a, a + b, a + b + c]
         # num_rejected_tokens: [n1, n2, n3]
@@ -288,18 +289,13 @@ class EagleProposer:
 
         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-        cu_num_tokens = torch.empty_like(cu_target_query_lens)
+        cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-        cu_num_tokens[0] = 0
-
-        # FIXME(woosuk): Avoid synchronization.
-        num_tokens = cu_num_tokens[-1].item()
         token_indices = torch.empty(
             num_tokens,
             dtype=torch.int32,
-            device=cu_num_tokens.device,
+            device=cu_target_query_lens.device,
         )
 
         batch_size = num_rejected_tokens.shape[0]
         BLOCK_SIZE = 1024
         prepare_eagle_input_kernel[(batch_size, )](
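
Why the body gets simpler: torch.zeros_like already leaves cu_num_tokens[0] at zero, so the explicit cu_num_tokens[0] = 0 write goes away, and because the caller now passes num_tokens in, the cu_num_tokens[-1].item() read-back (the device-to-host sync the old FIXME complained about) goes away too. A minimal standalone sketch of the same cumulative-sum bookkeeping, using made-up query lengths and rejection counts:

import torch

# Toy values: per-request query lens [a, b, c] = [4, 3, 5], rejected [n1, n2, n3] = [1, 0, 2].
cu_target_query_lens = torch.tensor([0, 4, 7, 12], dtype=torch.int32)
num_rejected_tokens = torch.tensor([1, 0, 2], dtype=torch.int32)

query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]  # [4, 3, 5]
num_tokens_per_req = query_len_per_req - num_rejected_tokens              # [3, 3, 3]

# zeros_like keeps index 0 at zero; cumsum writes directly into the [1:] view.
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
print(cu_num_tokens.tolist())  # [0, 3, 6, 9]

# The total (9 here) is what the caller now supplies as num_tokens, so the
# kernel's output buffer can be sized without reading cu_num_tokens back.
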
@@ -34,8 +34,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        is_pin_memory_available)
+                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
+                        check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -281,7 +281,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         """
         Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
+        backend's needs. For example, some attention backends (namely MLA) may
         want to separate requests based on if the attention computation will be
         compute-bound or memory-bound.
@@ -1360,9 +1360,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                            scheduler_output.num_scheduled_tokens[req_id])
                 next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
+            next_token_ids = async_tensor_h2d(next_token_ids,
+                                              dtype=torch.int32,
+                                              target_device=self.device,
+                                              pin_memory=True)
             eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
 
             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
@@ -1390,14 +1391,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                 for i, n in enumerate(num_draft_tokens)
             ]
-            num_rejected_tokens = torch.tensor(
+            num_rejected_tokens_tensor = async_tensor_h2d(
                 num_rejected_tokens,
                 dtype=torch.int32,
-                device=self.device,
-            )
+                target_device=self.device,
+                pin_memory=True)
+            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
             cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                 eagle_attn_metadata.query_start_loc,
-                num_rejected_tokens,
+                num_rejected_tokens_tensor,
+                num_tokens,
             )
             target_token_ids = self.input_ids[token_indices]
             target_positions = positions[token_indices]
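
On the runner side the same total becomes pure host arithmetic: num_rejected_tokens is a Python list and num_scheduled_tokens a host-side count, so subtracting sum(num_rejected_tokens) yields the num_tokens argument without touching the GPU, while the device copy of the rejection counts (num_rejected_tokens_tensor) is still what gets handed to prepare_inputs. A tiny worked example with made-up numbers:

# Hypothetical 3-request batch.
num_scheduled_tokens = 12        # total target tokens scheduled this step (host value)
num_rejected_tokens = [1, 0, 2]  # host-side list built in the hunk above

# Same expression as the new runner code: plain Python, no device sync.
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
assert num_tokens == 9
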
@@ -1408,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             target_hidden_states = hidden_states[token_indices]
             target_slot_mapping = eagle_attn_metadata.slot_mapping[
                 token_indices]
-
             draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,