mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 21:07:09 +08:00
[V1][Spec Decode] Small refactors to improve eagle bookkeeping performance (#18424)
Signed-off-by: qizixi <qizixi@meta.com>
This commit is contained in:
parent
ec82c3e388
commit
d55e446d13
@ -100,8 +100,12 @@ def test_prepare_inputs():
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
|
# a + b + c - n1 - n2 - n3
|
||||||
|
num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
|
||||||
|
).item()
|
||||||
|
|
||||||
cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
|
cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
|
||||||
cu_target_query_lens, num_rejected_tokens)
|
cu_target_query_lens, num_rejected_tokens, num_tokens)
|
||||||
|
|
||||||
assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
|
assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
|
||||||
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
|
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
|
||||||
|
|||||||
@ -271,6 +271,7 @@ class EagleProposer:
|
|||||||
cu_target_query_lens: torch.Tensor,
|
cu_target_query_lens: torch.Tensor,
|
||||||
# [batch_size]
|
# [batch_size]
|
||||||
num_rejected_tokens: torch.Tensor,
|
num_rejected_tokens: torch.Tensor,
|
||||||
|
num_tokens: int,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# cu_target_query_lens: [0, a, a + b, a + b + c]
|
# cu_target_query_lens: [0, a, a + b, a + b + c]
|
||||||
# num_rejected_tokens: [n1, n2, n3]
|
# num_rejected_tokens: [n1, n2, n3]
|
||||||
@ -288,18 +289,13 @@ class EagleProposer:
|
|||||||
|
|
||||||
# [a - n1, b - n2, c - n3] ->
|
# [a - n1, b - n2, c - n3] ->
|
||||||
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
|
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
|
||||||
cu_num_tokens = torch.empty_like(cu_target_query_lens)
|
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
|
||||||
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
|
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
|
||||||
cu_num_tokens[0] = 0
|
|
||||||
|
|
||||||
# FIXME(woosuk): Avoid synchronization.
|
|
||||||
num_tokens = cu_num_tokens[-1].item()
|
|
||||||
token_indices = torch.empty(
|
token_indices = torch.empty(
|
||||||
num_tokens,
|
num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=cu_num_tokens.device,
|
device=cu_target_query_lens.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_size = num_rejected_tokens.shape[0]
|
batch_size = num_rejected_tokens.shape[0]
|
||||||
BLOCK_SIZE = 1024
|
BLOCK_SIZE = 1024
|
||||||
prepare_eagle_input_kernel[(batch_size, )](
|
prepare_eagle_input_kernel[(batch_size, )](
|
||||||
|
|||||||
@ -34,8 +34,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
|
|||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||||
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
|
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
|
||||||
is_pin_memory_available)
|
check_use_alibi, is_pin_memory_available)
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
@ -1360,9 +1360,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
scheduler_output.num_scheduled_tokens[req_id])
|
scheduler_output.num_scheduled_tokens[req_id])
|
||||||
next_token_id = req_state.get_token_id(seq_len)
|
next_token_id = req_state.get_token_id(seq_len)
|
||||||
next_token_ids.append(next_token_id)
|
next_token_ids.append(next_token_id)
|
||||||
next_token_ids = torch.tensor(next_token_ids,
|
next_token_ids = async_tensor_h2d(next_token_ids,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
target_device=self.device,
|
||||||
|
pin_memory=True)
|
||||||
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
|
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
|
||||||
|
|
||||||
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
|
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
|
||||||
@ -1390,14 +1391,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
|
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
|
||||||
for i, n in enumerate(num_draft_tokens)
|
for i, n in enumerate(num_draft_tokens)
|
||||||
]
|
]
|
||||||
num_rejected_tokens = torch.tensor(
|
num_rejected_tokens_tensor = async_tensor_h2d(
|
||||||
num_rejected_tokens,
|
num_rejected_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device,
|
target_device=self.device,
|
||||||
)
|
pin_memory=True)
|
||||||
|
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
|
||||||
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
|
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
|
||||||
eagle_attn_metadata.query_start_loc,
|
eagle_attn_metadata.query_start_loc,
|
||||||
num_rejected_tokens,
|
num_rejected_tokens_tensor,
|
||||||
|
num_tokens,
|
||||||
)
|
)
|
||||||
target_token_ids = self.input_ids[token_indices]
|
target_token_ids = self.input_ids[token_indices]
|
||||||
target_positions = positions[token_indices]
|
target_positions = positions[token_indices]
|
||||||
@ -1408,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
target_hidden_states = hidden_states[token_indices]
|
target_hidden_states = hidden_states[token_indices]
|
||||||
target_slot_mapping = eagle_attn_metadata.slot_mapping[
|
target_slot_mapping = eagle_attn_metadata.slot_mapping[
|
||||||
token_indices]
|
token_indices]
|
||||||
|
|
||||||
draft_token_ids = self.drafter.propose(
|
draft_token_ids = self.drafter.propose(
|
||||||
target_token_ids=target_token_ids,
|
target_token_ids=target_token_ids,
|
||||||
target_positions=target_positions,
|
target_positions=target_positions,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user