[BugFix] Fix returned logprobs with spec decode + prefill chunking (#29216)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-11-22 06:41:25 -08:00 committed by GitHub
parent 066209a045
commit d44a63c6d6
3 changed files with 22 additions and 15 deletions
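
The bug: with chunked prefill enabled, a request whose prompt has not finished prefilling has its sampled tokens discarded at the end of the step, but the sampler's flat logprobs tensors still contain rows for it. The old code accumulated the per-request token counts used to split those tensors from the already-cleared token lists, so with speculative decoding (a variable number of accepted tokens per request) every request after a discarded one read the wrong slice. A toy illustration of the misalignment (plain Python, not vLLM code):

# Sampler output for 3 requests with 2, 1 and 2 new tokens, laid out flat.
flat = ["r0t0", "r0t1", "r1t0", "r2t0", "r2t1"]
new_tokens = [["a", "b"], ["c"], ["d", "e"]]

# Offsets taken BEFORE the discard match the sampler layout (the fix).
pre = [0]
for toks in new_tokens:
    pre.append(pre[-1] + len(toks))           # [0, 2, 3, 5]

# Request 1 is still mid-prefill under chunked prefill: its tokens are discarded.
new_tokens[1].clear()

# Offsets taken AFTER the discard (old behaviour) skip request 1.
post = [0]
for toks in new_tokens:
    post.append(post[-1] + len(toks))         # [0, 2, 2, 4]

print(flat[post[2]:post[3]])                  # ['r1t0', 'r2t0']  -> wrong logprobs for request 2
print(flat[pre[2]:pre[3]])                    # ['r2t0', 'r2t1']  -> correct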

View File

@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
         pytest.param(
             (
                 "eagle",
-                "meta-llama/Llama-3.1-8B-Instruct",
-                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+                "meta-llama/Llama-3.2-1B-Instruct",
+                "nm-testing/Llama3_2_1B_speculator.eagle3",
             ),
             marks=large_gpu_mark(min_gb=32),
         ),
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
     """
     from vllm import LLM
-    prompt = "Hello world"
+    prompt = "Hello world " * 50
     sampling_params = SamplingParams(
         temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
     )
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
         seed=42,
         logprobs_mode=logprobs_mode,
         gpu_memory_utilization=0.4,
+        # Force prefill chunking
+        enable_chunked_prefill=True,
+        max_num_batched_tokens=32,
     )
     spec_results = spec_llm.generate([prompt], sampling_params)
     # Collect logprobs outputs from spec decode LLM.
@@ -597,6 +600,8 @@ def test_spec_decode_logprobs(
     # Per-token logprobs are expected to be the same.
     assert len(ref_logprobs) == len(spec_logprobs)
     for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
-        assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+        assert math.isclose(
+            ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
+        )
         assert ref_logprob.rank == spec_logprob.rank
         assert ref_logprob.decoded_token == spec_logprob.decoded_token
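
The test now exercises exactly this path: a ~150-token prompt with max_num_batched_tokens=32 forces the prompt to be prefilled over several scheduler steps while speculative decoding is active, the target/draft pair is swapped for a 1B model with an EAGLE3 speculator so the case runs on smaller GPUs, and the per-token comparison against the non-speculative reference LLM is loosened from abs_tol=1e-3 to rel_tol=5e-2 / abs_tol=1e-1. A hedged sketch of the spec-decode LLM construction (only the arguments visible in this diff are taken from the test; the speculative_config shape and values are assumptions):

from vllm import LLM, SamplingParams

prompt = "Hello world " * 50  # long enough to span several 32-token chunks
sampling_params = SamplingParams(temperature=0, logprobs=3, max_tokens=10, ignore_eos=False)

spec_llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    speculative_config={  # assumed shape, built from the parametrize tuple above
        "method": "eagle",
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
        "num_speculative_tokens": 3,
    },
    seed=42,
    gpu_memory_utilization=0.4,
    # Force prefill chunking: only 32 prompt tokens fit per scheduler step.
    enable_chunked_prefill=True,
    max_num_batched_tokens=32,
)
spec_results = spec_llm.generate([prompt], sampling_params)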

View File

@@ -81,7 +81,10 @@ class Sampler(nn.Module):
         if logprobs_mode == "raw_logprobs":
             raw_logprobs = self.compute_logprobs(logits)
         elif logprobs_mode == "raw_logits":
-            raw_logprobs = logits.clone()
+            if logits.dtype == torch.float32:
+                raw_logprobs = logits.clone()
+            else:
+                raw_logprobs = logits.to(torch.float32)
         # Use float32 for the logits.
         logits = logits.to(torch.float32)
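
In "raw_logits" mode the sampler now returns the raw logits in float32 regardless of the model dtype. torch.Tensor.to() already produces a copy when the dtype differs, so the explicit clone() is only needed when the logits are already float32, where .to() would hand back the same tensor instead of a copy. A quick demonstration of that PyTorch behaviour (toy tensors):

import torch

x16 = torch.randn(2, 4, dtype=torch.bfloat16)
y = x16.to(torch.float32)       # dtype differs -> new tensor (copy + upcast)
assert y.data_ptr() != x16.data_ptr()

x32 = torch.randn(2, 4, dtype=torch.float32)
z = x32.to(torch.float32)       # dtype already matches -> same tensor, no copy
assert z is x32                 # hence the clone() branch above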

View File

@@ -2466,7 +2466,9 @@ class GPUModelRunner(
         num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
+        logprobs_tensors = sampler_output.logprobs_tensors
         invalid_req_indices = []
+        cu_num_new_tokens: list[int] | None = None
         if not self.use_async_scheduling:
             # Get the valid generated tokens.
             max_gen_len = sampled_token_ids.shape[-1]
@@ -2479,6 +2481,12 @@ class GPUModelRunner(
                 sampled_token_ids,
                 self.input_batch.vocab_size,
             )
+            if logprobs_tensors:
+                # Needed for extracting logprobs when spec decoding.
+                # This must be done prior to discarding sampled tokens.
+                cu_num_new_tokens = [0]
+                for toks in valid_sampled_token_ids:
+                    cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
             # Mask out the sampled tokens that should not be sampled.
             for i in discard_sampled_tokens_req_indices:
                 valid_sampled_token_ids[int(i)].clear()
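
The cumulative counts are computed from valid_sampled_token_ids before the still-prefilling requests are cleared just below, so the offsets line up with the row layout of the sampler's logprobs tensors even when speculative decoding accepts a variable number of tokens per request. The loop is a plain cumulative sum of list lengths; an equivalent restatement with toy data:

from itertools import accumulate

valid_sampled_token_ids = [[11, 12, 13], [7], [], [42, 43]]  # toy per-request tokens
cu_num_new_tokens = [0, *accumulate(len(t) for t in valid_sampled_token_ids)]
# -> [0, 3, 4, 4, 6]; rows cu[i]:cu[i+1] of the flat logprobs belong to request i.
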
@@ -2506,10 +2514,6 @@ class GPUModelRunner(
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         req_ids = self.input_batch.req_ids
-        logprobs_tensors = sampler_output.logprobs_tensors
-        cu_num_accepted_tokens = (
-            [0] if spec_decode_metadata and logprobs_tensors else None
-        )
         for req_idx in range(num_sampled_tokens):
             if self.use_async_scheduling:
                 sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@@ -2518,11 +2522,6 @@ class GPUModelRunner(
             num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
-            if cu_num_accepted_tokens is not None:
-                cu_num_accepted_tokens.append(
-                    cu_num_accepted_tokens[-1] + num_sampled_ids
-                )
             if not sampled_ids:
                 continue
@@ -2544,7 +2543,7 @@ class GPUModelRunner(
             req_state.output_token_ids.extend(sampled_ids)
         logprobs_lists = (
-            logprobs_tensors.tolists(cu_num_accepted_tokens)
+            logprobs_tensors.tolists(cu_num_new_tokens)
             if not self.use_async_scheduling and logprobs_tensors is not None
             else None
         )
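
With the offsets fixed, logprobs_tensors.tolists() splits the flat per-token logprobs into per-request lists at the correct boundaries; the old cu_num_accepted_tokens undercounted whenever a request's sampled tokens had been discarded, shifting every later request's logprobs. Conceptually the split works like this (toy lists, not the real LogprobsTensors.tolists implementation):

def split_by_offsets(flat_rows, cu_counts):
    # cu_counts[i]:cu_counts[i + 1] brackets request i's rows.
    return [flat_rows[s:e] for s, e in zip(cu_counts, cu_counts[1:])]

rows = ["r0t0", "r0t1", "r0t2", "r1t0", "r3t0", "r3t1"]
print(split_by_offsets(rows, [0, 3, 4, 4, 6]))
# [['r0t0', 'r0t1', 'r0t2'], ['r1t0'], [], ['r3t0', 'r3t1']]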