[bugfix] fix logprobs being skipped when top_logprobs=0 with spec decoding: check `max_num_logprobs is not None` instead of truthiness (#30059)

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
realliujiaxu 2025-12-13 01:03:35 +08:00 committed by GitHub
parent f3237f3f6b
commit d2c919dcc2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 3 deletions

View File

@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
), ),
], ],
) )
@pytest.mark.parametrize("top_logprobs", [0, 3])
def test_spec_decode_logprobs( def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode, logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str], model_setup: tuple[str, str, str],
top_logprobs: int,
): ):
"""Spec decode logprobs should match those of the base model. """Spec decode logprobs should match those of the base model.
@ -543,7 +545,7 @@ def test_spec_decode_logprobs(
prompt = "Hello world " * 50 prompt = "Hello world " * 50
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
) )
method, model_name, spec_model_name = model_setup method, model_name, spec_model_name = model_setup
max_model_len = 256 max_model_len = 256

View File

@ -111,7 +111,7 @@ def create_sampling_metadata(
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
generators=generators, generators=generators,
max_num_logprobs=0, max_num_logprobs=None,
no_penalties=no_penalties, no_penalties=no_penalties,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
frequency_penalties=frequency_penalties, frequency_penalties=frequency_penalties,

View File

@ -145,7 +145,7 @@ class RejectionSampler(nn.Module):
) )
logprobs_tensors = None logprobs_tensors = None
if sampling_metadata.max_num_logprobs: if sampling_metadata.max_num_logprobs is not None:
logprobs_tensors = self._get_logprobs_tensors( logprobs_tensors = self._get_logprobs_tensors(
sampling_metadata.max_num_logprobs, sampling_metadata.max_num_logprobs,
metadata, metadata,