mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 02:15:57 +08:00
[bugfix] fix bug when top_logprobs=0 with spec decoding (#30059)
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
parent
f3237f3f6b
commit
d2c919dcc2
@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@pytest.mark.parametrize("top_logprobs", [0, 3])
|
||||||
def test_spec_decode_logprobs(
|
def test_spec_decode_logprobs(
|
||||||
logprobs_mode: LogprobsMode,
|
logprobs_mode: LogprobsMode,
|
||||||
model_setup: tuple[str, str, str],
|
model_setup: tuple[str, str, str],
|
||||||
|
top_logprobs: int,
|
||||||
):
|
):
|
||||||
"""Spec decode logprobs should match those of the base model.
|
"""Spec decode logprobs should match those of the base model.
|
||||||
|
|
||||||
@ -543,7 +545,7 @@ def test_spec_decode_logprobs(
|
|||||||
|
|
||||||
prompt = "Hello world " * 50
|
prompt = "Hello world " * 50
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
|
temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
|
||||||
)
|
)
|
||||||
method, model_name, spec_model_name = model_setup
|
method, model_name, spec_model_name = model_setup
|
||||||
max_model_len = 256
|
max_model_len = 256
|
||||||
|
|||||||
@ -111,7 +111,7 @@ def create_sampling_metadata(
|
|||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
generators=generators,
|
generators=generators,
|
||||||
max_num_logprobs=0,
|
max_num_logprobs=None,
|
||||||
no_penalties=no_penalties,
|
no_penalties=no_penalties,
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
frequency_penalties=frequency_penalties,
|
frequency_penalties=frequency_penalties,
|
||||||
|
|||||||
@ -145,7 +145,7 @@ class RejectionSampler(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logprobs_tensors = None
|
logprobs_tensors = None
|
||||||
if sampling_metadata.max_num_logprobs:
|
if sampling_metadata.max_num_logprobs is not None:
|
||||||
logprobs_tensors = self._get_logprobs_tensors(
|
logprobs_tensors = self._get_logprobs_tensors(
|
||||||
sampling_metadata.max_num_logprobs,
|
sampling_metadata.max_num_logprobs,
|
||||||
metadata,
|
metadata,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user