[SpecDecode] Simplified alternative padded-speculation acceptance rate fix (#29845)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 9586354053
commit de71747655
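For orientation, a minimal sketch (not part of this commit) of the bookkeeping the change keeps on the GPU: prepare_inputs_padded now also returns num_rejected_tokens_gpu, and the drafter uses it to shrink seq_lens for the padded batch without a device-to-host sync. The tensor values below are made up, and the formula assumes valid_sampled_tokens_count counts the accepted draft tokens plus the bonus token.

    import torch

    # Made-up per-request values, for illustration only.
    num_draft_tokens = torch.tensor([3, 3, 3], dtype=torch.int32)
    valid_sampled_tokens_count = torch.tensor([3, 4, 2], dtype=torch.int32)

    # Rejected drafts per request; in vLLM these tensors live on the GPU, and the
    # subtraction avoids any .item()/.cpu() round-trip.
    num_rejected_tokens_gpu = num_draft_tokens - (valid_sampled_tokens_count - 1)
    # -> tensor([1, 0, 2], dtype=torch.int32)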
@@ -306,10 +306,16 @@ def test_prepare_inputs_padded():

    proposer = _create_proposer("eagle", num_speculative_tokens)

    output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
    output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = (
        proposer.prepare_inputs_padded(
            common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
        )
    )

    # Verify num_rejected_tokens_gpu is calculated correctly
    expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device)
    assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)

    assert output_metadata.max_query_len == 3
    assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
    assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
@@ -564,6 +564,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
        self.dcp_rank = 0
        self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
        self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
        self.cp_kv_cache_interleave_size = parallel_config.cp_kv_cache_interleave_size

        # Don't try to access the runner on AMD
        if self.aot_schedule:
@@ -727,8 +728,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        max_seq_len: int,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
@@ -778,13 +779,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
        dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
@@ -799,6 +794,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):

        prefill_metadata = None
        if num_prefills > 0:
            num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu

            reqs_start = num_decodes  # prefill_start

            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
@@ -995,13 +992,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
        dcp_tot_seq_lens_device = None
        if self.dcp_world_size > 1:
            dcp_tot_seq_lens_device = seq_lens[:num_decodes]
            seq_lens_cpu = dcp_local_seq_lens_cpu
            seq_lens = dcp_local_seq_lens

            # After DCP distribution, the maximum number of tokens for any rank is
            # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
            # and I is cp_kv_cache_interleave_size.
            # This eliminates GPU->CPU sync while minimizing workspace
            # over-allocation.
            num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
            max_seq_len = (
                (max_seq_len + num_partitions - 1) // num_partitions
            ) * self.cp_kv_cache_interleave_size
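            # Illustrative numbers (not from this commit): with max_seq_len=1000,
            # dcp_world_size=4, and cp_kv_cache_interleave_size=64, num_partitions is
            # 256 and the bound is ceil(1000 / 256) * 64 = 4 * 64 = 256, which covers
            # the busiest rank (256 tokens) without reading seq_lens back to the CPU.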

        decode_metadata = self._build_decode(
            block_table_tensor=block_table_tensor[:num_decodes, ...],
            seq_lens_cpu=seq_lens_cpu[:num_decodes],
            seq_lens_device=seq_lens[:num_decodes],
            max_seq_len=max_seq_len,
            query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
            query_start_loc_device=query_start_loc[: num_decodes + 1],
            num_decode_tokens=num_decode_tokens,
@@ -169,8 +169,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        max_seq_len: int,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
@@ -178,7 +178,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
    ) -> FlashAttnMLADecodeMetadata:
        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        max_query_len = query_lens_cpu.max().item()
        max_seq_len = seq_lens_cpu.max().item()

        # For Flash Attention MLA + full cudagraph
        max_num_splits = 0
@@ -193,7 +192,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
            max_num_splits = 1

        scheduler_metadata = self._schedule_decode(
            num_reqs=seq_lens_cpu.numel(),
            num_reqs=seq_lens_device.shape[0],
            cu_query_lens=query_start_loc_device,
            max_query_len=max_query_len,
            seqlens=seq_lens_device,
@@ -143,8 +143,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        max_seq_len: int,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
@@ -106,8 +106,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        max_seq_len: int,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
@@ -236,6 +236,7 @@ class EagleProposer:
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
    ) -> torch.Tensor:
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
@@ -414,6 +415,17 @@ class EagleProposer:
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()

        # In the padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None
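            # Illustrative values (made up, not from this change): if seq_lens were
            # [18, 25, 9] and num_rejected_tokens_gpu were [2, 0, 1], the drafter
            # would now see seq_lens [16, 25, 8]; the subtraction stays on the GPU.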

        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
@@ -628,13 +640,14 @@ class EagleProposer:
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.
        """
        num_reqs = common_attn_metadata.num_reqs
        device = valid_sampled_tokens_count.device
@@ -642,14 +655,17 @@ class EagleProposer:
        token_indices_to_sample = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )
        num_rejected_tokens_gpu = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )

        # Kernel grid: one program per request (row)
        grid = (num_reqs,)
        eagle_prepare_inputs_padded_kernel[grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            num_reqs,
        )
@@ -674,7 +690,11 @@ class EagleProposer:
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices_to_sample
        return (
            spec_common_attn_metadata,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
        )

    def propose_tree(
        self,
@@ -23,6 +23,7 @@ def eagle_prepare_inputs_padded_kernel(
    valid_sampled_tokens_count_ptr,  # [num_reqs]
    query_start_loc_gpu_ptr,  # [num_reqs + 1]
    token_indices_to_sample_ptr,  # [num_reqs] (output)
    num_rejected_tokens_gpu_ptr,  # [num_reqs] (output)
    num_reqs,  # tl.int32
):
    """
@@ -56,6 +57,7 @@ def eagle_prepare_inputs_padded_kernel(

    index_to_sample = q_last_tok_idx - num_rejected_tokens
    tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
    tl.store(num_rejected_tokens_gpu_ptr + req_idx, num_rejected_tokens)


@triton.jit
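For reference, a plain-PyTorch sketch of what the kernel above appears to compute per request. It is not part of the commit, and it assumes cu_num_draft_tokens is the inclusive cumulative sum of per-request draft counts (no leading zero) and that valid_sampled_tokens_count includes the bonus token.

    import torch

    def prepare_inputs_padded_reference(
        cu_num_draft_tokens: torch.Tensor,         # [num_reqs], inclusive cumsum
        valid_sampled_tokens_count: torch.Tensor,  # [num_reqs], accepted + bonus
        query_start_loc: torch.Tensor,             # [num_reqs + 1]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Recover per-request draft counts from the cumulative sum.
        num_draft = torch.diff(
            cu_num_draft_tokens, prepend=cu_num_draft_tokens.new_zeros(1)
        )
        # Draft tokens the target model rejected.
        num_rejected = num_draft - (valid_sampled_tokens_count - 1)
        # Flat index of each request's last accepted token in the padded batch.
        token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected
        return token_indices_to_sample, num_rejected

The Triton kernel does the same arithmetic with one program per request, so nothing is read back to the host.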
@@ -3534,6 +3534,7 @@ class GPUModelRunner(
            next_token_ids, valid_sampled_tokens_count
        )

        num_rejected_tokens_gpu = None
        if spec_decode_metadata is None:
            token_indices_to_sample = None
            # input_ids can be None for multimodal models.
@@ -3564,12 +3565,14 @@ class GPUModelRunner(
            else:
                target_hidden_states = hidden_states[token_indices]
        else:
            common_attn_metadata, token_indices_to_sample = (
                self.drafter.prepare_inputs_padded(
                    common_attn_metadata,
                    spec_decode_metadata,
                    valid_sampled_tokens_count,
                )
            (
                common_attn_metadata,
                token_indices_to_sample,
                num_rejected_tokens_gpu,
            ) = self.drafter.prepare_inputs_padded(
                common_attn_metadata,
                spec_decode_metadata,
                valid_sampled_tokens_count,
            )
            total_num_tokens = common_attn_metadata.num_actual_tokens
            # When padding the batch, token_indices is just a range
@@ -3600,6 +3603,7 @@ class GPUModelRunner(
            sampling_metadata=sampling_metadata,
            common_attn_metadata=common_attn_metadata,
            mm_embed_inputs=mm_embed_inputs,
            num_rejected_tokens_gpu=num_rejected_tokens_gpu,
        )

        return draft_token_ids