mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 07:05:02 +08:00
[Bugfix][Spec Decode] Fix wrong valid_mask for padded speculation when chunked prefill occurs (#26231)
Signed-off-by: seven-mile <i@7li.moe> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
parent
824a3f403f
commit
b2ea5ba677
@ -522,13 +522,9 @@ class EagleProposer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Generate a mask for all valid tokens within those requests
|
# Generate a mask for all valid tokens within those requests
|
||||||
max_gen_len = sampled_token_ids.shape[-1]
|
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
|
||||||
if max_gen_len == 1:
|
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
|
||||||
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
|
)
|
||||||
else:
|
|
||||||
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
|
|
||||||
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Count the number of valid tokens in each request
|
# Count the number of valid tokens in each request
|
||||||
valid_sampled_tokens_count = valid_mask.sum(dim=1)
|
valid_sampled_tokens_count = valid_mask.sum(dim=1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user