mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 22:16:14 +08:00
[bugfix] fix bug when top_logprobs=0 with spec decoding (#30059)
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
parent
f3237f3f6b
commit
d2c919dcc2
@@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
|
||||
),
|
||||
],
|
||||
)
|
||||
+ @pytest.mark.parametrize("top_logprobs", [0, 3])
|
||||
def test_spec_decode_logprobs(
|
||||
logprobs_mode: LogprobsMode,
|
||||
model_setup: tuple[str, str, str],
|
||||
+ top_logprobs: int,
|
||||
):
|
||||
"""Spec decode logprobs should match those of the base model.
|
||||
|
||||
@@ -543,7 +545,7 @@ def test_spec_decode_logprobs(
|
||||
|
||||
prompt = "Hello world " * 50
|
||||
sampling_params = SamplingParams(
|
||||
- temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
|
||||
+ temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
|
||||
)
|
||||
method, model_name, spec_model_name = model_setup
|
||||
max_model_len = 256
|
||||
|
||||
@@ -111,7 +111,7 @@ def create_sampling_metadata(
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
generators=generators,
|
||||
- max_num_logprobs=0,
|
||||
+ max_num_logprobs=None,
|
||||
no_penalties=no_penalties,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
frequency_penalties=frequency_penalties,
|
||||
|
||||
@@ -145,7 +145,7 @@ class RejectionSampler(nn.Module):
|
||||
)
|
||||
|
||||
logprobs_tensors = None
|
||||
- if sampling_metadata.max_num_logprobs:
|
||||
+ if sampling_metadata.max_num_logprobs is not None:
|
||||
logprobs_tensors = self._get_logprobs_tensors(
|
||||
sampling_metadata.max_num_logprobs,
|
||||
metadata,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user