[Bugfix][Spec Decode] Fix wrong valid_mask for padded speculation when chunked prefill occurs (#26231)

Signed-off-by: seven-mile <i@7li.moe>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
7mile authored 2025-10-07 02:24:22 +08:00, committed by GitHub
parent 824a3f403f
commit b2ea5ba677


@@ -522,13 +522,9 @@ class EagleProposer:
         )
         # Generate a mask for all valid tokens within those requests
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
-        else:
-            valid_mask = (valid_sampled_token_ids_gpu != -1) & (
-                valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
-            )
+        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
+            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
+        )
         # Count the number of valid tokens in each request
         valid_sampled_tokens_count = valid_mask.sum(dim=1)
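The removed fast path assumed that when max_gen_len == 1 every request contributed exactly one valid token, but the fix implies that with chunked prefill a request still mid-prefill can carry a -1 placeholder in its slot, which the all-ones mask would count as valid. The sketch below, which is not part of the commit, illustrates the discrepancy; the tensor values and vocab_size are invented for illustration.

```python
import torch

# Hypothetical values; in vLLM the real vocab size and sampled ids come
# from gpu_input_batch and the sampler.
vocab_size = 32000

# Three requests, one sampled token each (max_gen_len == 1). The second
# request is still mid chunked prefill, so its slot holds the -1 placeholder.
valid_sampled_token_ids_gpu = torch.tensor([[17], [-1], [923]])

# Old fast path for max_gen_len == 1: assume every token is valid.
old_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
print(old_mask.sum(dim=1))  # tensor([1, 1, 1]) -- the -1 placeholder is counted

# Fixed path: mask out -1 placeholders and padded ids >= vocab_size.
new_mask = (valid_sampled_token_ids_gpu != -1) & (
    valid_sampled_token_ids_gpu < vocab_size
)
print(new_mask.sum(dim=1))  # tensor([1, 0, 1]) -- the placeholder is excluded
```

Dropping the special case means the single-token and multi-token cases share one mask computation, so a placeholder produced during chunked prefill can no longer inflate valid_sampled_tokens_count.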