[Bugfix] Unify rank computation across regular decoding and speculative decoding (#7899)
This commit is contained in:
parent ef99a78760
commit f205c09854
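In regular decoding, `_get_ranks` (imported below from vllm.model_executor.layers.sampler) reports a sampled token's rank as one plus the number of vocab entries with a strictly greater logprob. The speculative-decoding helper `get_sampled_token_logprobs` instead counted entries greater than or equal to the selected logprob. The two conventions agree when logprobs are distinct but diverge on ties, where the old count reports the worst tied rank instead of the best. A minimal standalone sketch of the discrepancy (plain PyTorch, not vLLM code):

import torch

# Two vocab entries tied at the same logprob; token 0 is the sampled one.
logprobs = torch.tensor([[-0.1, -0.1]])
selected = logprobs[:, 0:1]

old_rank = (logprobs >= selected).sum(-1)         # tensor([2]): the tie inflates the rank
new_rank = (logprobs > selected).sum(-1).add_(1)  # tensor([1]): strictly-greater count + 1
assert old_rank.item() == 2 and new_rank.item() == 1

The first two hunks below add a regression test exercising exactly this tie case; the final hunk changes the comparison in `get_sampled_token_logprobs`.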
@@ -4,10 +4,12 @@ import pytest
 import torch
 
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
+from vllm.model_executor.layers.sampler import _get_ranks
 from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
 from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
-from vllm.spec_decode.util import split_batch_by_proposal_len
+from vllm.spec_decode.util import (get_sampled_token_logprobs,
+                                   split_batch_by_proposal_len)
 
 
 def test_get_all_seq_ids():
@@ -126,3 +128,20 @@ def mock_spec_decode_sampler(acceptance_sampler_method):
         return sampler
     else:
         raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
+
+
+def test_get_sampled_token_logprobs():
+    """Verify get_sampled_token_logprobs returns consistent rankings
+    with regular get_ranks when probabilities match exactly.
+    """
+    logprob_tensor = torch.tensor(
+        [[[-.1, -.1]] * 2])  # shape (num_steps, batch_size, vocab_size)
+    sampled_token_tensor = torch.tensor([[1,
+                                          0]])  # shape (num_steps, batch_size)
+    ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
+                                                   sampled_token_tensor)
+
+    ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
+                               sampled_token_tensor.reshape(-1))
+
+    assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)
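The test deliberately makes every vocab entry share the same logprob (-.1): under the old `>=` count each sampled token would be reported at rank 2, while `_get_ranks` reports rank 1 for a tied best token, so this assertion fails before the fix below and passes after it.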
vllm/spec_decode/util.py
@@ -43,8 +43,8 @@ def get_sampled_token_logprobs(
                                        sampled_token_ids, ]
     expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
         -1, -1, vocab_size)
-    sampled_token_ids_ranks = (logprob_tensor >=
-                               expanded_selected_logprobs).sum(-1)
+    sampled_token_ids_ranks = (logprob_tensor >
+                               expanded_selected_logprobs).sum(-1).add_(1)
 
     return sampled_token_ids_ranks, selected_logprobs
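After this change, both decoding paths define rank as one plus the number of strictly greater logprobs, so the argmax token is rank 1 and tied tokens share the best rank rather than the worst. A minimal sketch of that shared convention (assuming `_get_ranks` follows the same definition; illustrative code, not vLLM's implementation):

import torch

def rank_of_sampled(logprobs: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    # logprobs: (batch, vocab); token_ids: (batch,)
    # rank = 1 + count of vocab entries with a strictly greater logprob
    selected = logprobs[torch.arange(logprobs.shape[0]), token_ids]
    # .sum(-1) allocates a fresh tensor, so the in-place .add_(1) is safe,
    # mirroring the .add_(1) in the diff above.
    return (logprobs > selected.unsqueeze(-1)).sum(-1).add_(1)

# e.g. rank_of_sampled(torch.tensor([[-0.1, -0.1]]), torch.tensor([0])) -> tensor([1])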