[SpecDecode] Simplified alternative padded-speculation acceptance rate fix (#29845)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Authored by Lucas Wilkinson on 2025-12-22 16:06:10 -05:00; committed by GitHub
parent 9586354053
commit de71747655
8 changed files with 62 additions and 25 deletions

@@ -306,10 +306,16 @@ def test_prepare_inputs_padded():
proposer = _create_proposer("eagle", num_speculative_tokens)
output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = (
proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
)
)
# Verify num_rejected_tokens_gpu is calculated correctly
expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device)
assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)
assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
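
For reference (not part of the patch): a minimal sketch of the arithmetic the new assertion checks, assuming two draft tokens per request and that valid_sampled_tokens_count counts accepted draft tokens plus the bonus token; the test's actual fixtures may differ.

import torch

num_draft_tokens = torch.tensor([2, 2, 2], dtype=torch.int32)
valid_sampled_tokens_count = torch.tensor([2, 3, 1], dtype=torch.int32)
# rejected = drafted - accepted, with accepted = valid_sampled - 1 (bonus token)
num_rejected = num_draft_tokens - (valid_sampled_tokens_count - 1)
assert torch.equal(num_rejected, torch.tensor([1, 0, 2], dtype=torch.int32))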

@@ -564,6 +564,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.dcp_rank = 0
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
self.cp_kv_cache_interleave_size = parallel_config.cp_kv_cache_interleave_size
# Don't try to access the runner on AMD
if self.aot_schedule:
@@ -727,8 +728,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
@@ -778,13 +779,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
@@ -799,6 +794,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata = None
if num_prefills > 0:
num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
@@ -995,13 +992,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
dcp_tot_seq_lens_device = None
if self.dcp_world_size > 1:
dcp_tot_seq_lens_device = seq_lens[:num_decodes]
seq_lens_cpu = dcp_local_seq_lens_cpu
seq_lens = dcp_local_seq_lens
# After DCP distribution, the maximum number of tokens for any rank is
# ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
# and I is cp_kv_cache_interleave_size.
# This eliminates GPU->CPU sync while minimizing workspace
# over-allocation.
num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
max_seq_len = (
(max_seq_len + num_partitions - 1) // num_partitions
) * self.cp_kv_cache_interleave_size
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
max_seq_len=max_seq_len,
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
query_start_loc_device=query_start_loc[: num_decodes + 1],
num_decode_tokens=num_decode_tokens,
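
For reference (not part of the patch): a worked example of the ceil(L / (N * I)) * I bound computed above; the numbers are illustrative only.

import math

max_seq_len = 1000                 # L: longest decode sequence in the batch
dcp_world_size = 8                 # N
cp_kv_cache_interleave_size = 16   # I

num_partitions = dcp_world_size * cp_kv_cache_interleave_size   # 128
bound = math.ceil(max_seq_len / num_partitions) * cp_kv_cache_interleave_size
# ceil(1000 / 128) = 8 interleaved blocks per rank at most, so the bound is
# 8 * 16 = 128 tokens, matching the integer form ((L + P - 1) // P) * I above.
assert bound == 128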

@@ -169,8 +169,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
@@ -178,7 +178,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
# For Flash Attention MLA + full cudagraph
max_num_splits = 0
@@ -193,7 +192,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
max_num_splits = 1
scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(),
num_reqs=seq_lens_device.shape[0],
cu_query_lens=query_start_loc_device,
max_query_len=max_query_len,
seqlens=seq_lens_device,
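
The deleted max_seq_len = seq_lens_cpu.max().item() line relied on a CPU copy of the decode sequence lengths; after this change those lengths may only be valid on the GPU (the drafter adjusts them in place), so recovering a host-side max would need a blocking device-to-host read. A small illustration of the read this avoids (hypothetical tensor, CUDA required; not part of the patch):

import torch

if torch.cuda.is_available():
    # Per-request sequence lengths that are only valid on the device.
    seq_lens_device = torch.randint(1, 4096, (256,), dtype=torch.int32, device="cuda")
    # This reduction plus .item() copies to the host and synchronizes the stream.
    max_seq_len_blocking = int(seq_lens_device.max().item())
    # The updated builder instead receives max_seq_len as a host-side int
    # computed (or bounded) without reading the device tensor back.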

@@ -143,8 +143,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,

@@ -106,8 +106,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,

@@ -236,6 +236,7 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
@@ -414,6 +415,17 @@ class EagleProposer:
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[: batch_size + 1]
).clone()
# In a padded drafter batch, we need to adjust the sequence lengths
# to remove the "padding" (i.e. rejected tokens).
# Only apply this adjustment when we have rejected tokens
# (i.e., not the first proposal).
if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
# Invalidate the CPU-side shadows to avoid H<>D sync.
common_attn_metadata._seq_lens_cpu = None
common_attn_metadata._num_computed_tokens_cpu = None
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
@@ -628,13 +640,14 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor,
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
"""
num_reqs = common_attn_metadata.num_reqs
device = valid_sampled_tokens_count.device
@@ -642,14 +655,17 @@ class EagleProposer:
token_indices_to_sample = torch.empty(
(num_reqs,), dtype=torch.int32, device=device
)
num_rejected_tokens_gpu = torch.empty(
(num_reqs,), dtype=torch.int32, device=device
)
# Kernel grid: one program per request (row)
grid = (num_reqs,)
eagle_prepare_inputs_padded_kernel[grid](
spec_decode_metadata.cu_num_draft_tokens,
valid_sampled_tokens_count,
common_attn_metadata.query_start_loc,
token_indices_to_sample,
num_rejected_tokens_gpu,
num_reqs,
)
@@ -674,7 +690,11 @@ class EagleProposer:
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
)
return spec_common_attn_metadata, token_indices_to_sample
return (
spec_common_attn_metadata,
token_indices_to_sample,
num_rejected_tokens_gpu,
)
def propose_tree(
self,
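
A small sketch of the seq_lens adjustment added to propose() above (hypothetical tensors, not the proposer's real state): every request keeps its full padded query, so the per-request KV lengths are shrunk by the number of rejected tokens, entirely on the GPU, and the stale CPU shadows are discarded rather than synced.

import torch

seq_lens = torch.tensor([19, 12, 27], dtype=torch.int32)
num_rejected_tokens_gpu = torch.tensor([1, 0, 2], dtype=torch.int32)
# Same element-wise update as common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
seq_lens -= num_rejected_tokens_gpu
assert torch.equal(seq_lens, torch.tensor([18, 12, 25], dtype=torch.int32))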

@@ -23,6 +23,7 @@ def eagle_prepare_inputs_padded_kernel(
valid_sampled_tokens_count_ptr, # [num_reqs]
query_start_loc_gpu_ptr, # [num_reqs + 1]
token_indices_to_sample_ptr, # [num_reqs] (output)
num_rejected_tokens_gpu_ptr, # [num_reqs] (output)
num_reqs, # tl.int32
):
"""
@@ -56,6 +57,7 @@ def eagle_prepare_inputs_padded_kernel(
index_to_sample = q_last_tok_idx - num_rejected_tokens
tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
tl.store(num_rejected_tokens_gpu_ptr + req_idx, num_rejected_tokens)
@triton.jit
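
A plain-PyTorch restatement of what the extended kernel stores per request (hypothetical inputs, not from the patch; it assumes valid_sampled_tokens_count equals accepted draft tokens plus one bonus token, which is how the num_rejected_tokens value above is most naturally derived):

import torch

cu_num_draft_tokens = torch.tensor([2, 4, 7], dtype=torch.int32)     # cumulative drafts
valid_sampled_tokens_count = torch.tensor([2, 1, 4], dtype=torch.int32)
query_start_loc = torch.tensor([0, 3, 6, 10], dtype=torch.int32)     # padded queries

num_draft = torch.diff(cu_num_draft_tokens, prepend=torch.zeros(1, dtype=torch.int32))
num_rejected_tokens = num_draft - (valid_sampled_tokens_count - 1)   # assumed relation
q_last_tok_idx = query_start_loc[1:] - 1
token_indices_to_sample = q_last_tok_idx - num_rejected_tokens

assert torch.equal(num_rejected_tokens, torch.tensor([1, 2, 0], dtype=torch.int32))
assert torch.equal(token_indices_to_sample, torch.tensor([1, 3, 9], dtype=torch.int32))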

@@ -3534,6 +3534,7 @@ class GPUModelRunner(
next_token_ids, valid_sampled_tokens_count
)
num_rejected_tokens_gpu = None
if spec_decode_metadata is None:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
@@ -3564,12 +3565,14 @@ class GPUModelRunner(
else:
target_hidden_states = hidden_states[token_indices]
else:
common_attn_metadata, token_indices_to_sample = (
self.drafter.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count,
)
(
common_attn_metadata,
token_indices_to_sample,
num_rejected_tokens_gpu,
) = self.drafter.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count,
)
total_num_tokens = common_attn_metadata.num_actual_tokens
# When padding the batch, token_indices is just a range
@@ -3600,6 +3603,7 @@ class GPUModelRunner(
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
return draft_token_ids