From 2ac85a4544cf9488037e61bf9ed7a87d0c3696bb Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 18 Dec 2025 19:58:28 -0800
Subject: [PATCH] [BugFix] Fix logprobs with spec decode and modified logits (#30846)

Signed-off-by: Nick Hill
---
 tests/v1/sample/test_logprobs.py    | 33 ++++++++++++++++++++---------
 vllm/v1/sample/rejection_sampler.py |  8 ++++++-
 2 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py
index 76a0e8e25a4ae..1e2cc2241ba95 100644
--- a/tests/v1/sample/test_logprobs.py
+++ b/tests/v1/sample/test_logprobs.py
@@ -547,6 +547,13 @@ def test_spec_decode_logprobs(
     sampling_params = SamplingParams(
         temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
     )
+    penalty_sampling_params = SamplingParams(
+        temperature=0,
+        logprobs=top_logprobs,
+        max_tokens=10,
+        ignore_eos=False,
+        presence_penalty=-1.0,
+    )
 
     method, model_name, spec_model_name = model_setup
     max_model_len = 256
@@ -558,14 +565,17 @@ def test_spec_decode_logprobs(
         seed=42,
         logprobs_mode=logprobs_mode,
         gpu_memory_utilization=0.4,
+        enable_prefix_caching=False,
+    )
+    ref_results = ref_llm.generate(
+        [prompt, prompt], [sampling_params, penalty_sampling_params]
     )
-    ref_results = ref_llm.generate([prompt], sampling_params)
     # Collect logprobs outputs from reference LLM.
     ref_logprobs = []
-    for output in ref_results[0].outputs:
-        for logprobs in output.logprobs:
-            for token_id in logprobs:
-                ref_logprobs.append(logprobs[token_id])
+    for results in ref_results:
+        for output in results.outputs:
+            for logprobs in output.logprobs:
+                ref_logprobs.extend(logprobs.values())
     del ref_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
@@ -587,14 +597,17 @@ def test_spec_decode_logprobs(
         # Force prefill chunking
         enable_chunked_prefill=True,
         max_num_batched_tokens=32,
+        enable_prefix_caching=False,
+    )
+    spec_results = spec_llm.generate(
+        [prompt, prompt], [sampling_params, penalty_sampling_params]
     )
-    spec_results = spec_llm.generate([prompt], sampling_params)
     # Collect logprobs outputs from spec decode LLM.
     spec_logprobs = []
-    for output in spec_results[0].outputs:
-        for logprobs in output.logprobs:
-            for token_id in logprobs:
-                spec_logprobs.append(logprobs[token_id])
+    for results in spec_results:
+        for output in results.outputs:
+            for logprobs in output.logprobs:
+                spec_logprobs.extend(logprobs.values())
     del spec_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 50b91d8292ee8..f2338e9b4b7d0 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -119,8 +119,14 @@ class RejectionSampler(nn.Module):
         raw_target_logits = logits[target_logits_indices]
         # Use float32 for the target_logits.
         raw_target_logits = raw_target_logits.to(torch.float32)
+        target_logits = raw_target_logits
+        if not self.is_processed_logprobs_mode:
+            # Clone raw_target_logits before applying processors to preserve
+            # the original raw logits for logprobs computation, since
+            # apply_logits_processors modifies the tensor in-place.
+            target_logits = target_logits.clone()
         target_logits = self.apply_logits_processors(
-            raw_target_logits, sampling_metadata, metadata
+            target_logits, sampling_metadata, metadata
         )
         # [num_tokens, vocab_size]
         # NOTE(woosuk): `target_logits` can be updated in place inside the
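
Below is a minimal, self-contained PyTorch sketch of the pattern the rejection_sampler.py hunk relies on: clone the raw logits before an in-place logits processor (for example a presence penalty) runs, so raw-mode logprobs can still be computed from the unmodified values. This is illustrative only and not vLLM code; the apply_presence_penalty_ helper, tensor shapes, and variable names are assumptions.

import torch


def apply_presence_penalty_(
    logits: torch.Tensor, seen_mask: torch.Tensor, penalty: float
) -> torch.Tensor:
    # Emulates an in-place logits processor: adjusts the logits of
    # already-seen tokens, mutating the input tensor directly.
    logits[seen_mask] -= penalty
    return logits


raw_logits = torch.randn(4, 32, dtype=torch.float32)  # [num_tokens, vocab_size], toy sizes
seen_mask = torch.zeros(4, 32, dtype=torch.bool)
seen_mask[:, 0] = True

# Clone before the in-place processor so the raw values survive.
processed_logits = apply_presence_penalty_(raw_logits.clone(), seen_mask, penalty=-1.0)

raw_logprobs = torch.log_softmax(raw_logits, dim=-1)              # what "raw" logprobs mode reports
processed_logprobs = torch.log_softmax(processed_logits, dim=-1)  # what "processed" mode reports

# Without the clone, both would reflect the penalized logits, which mirrors
# the issue the patch addresses for spec decode.
assert not torch.allclose(raw_logprobs, processed_logprobs)

Cloning only when the logprobs mode is not "processed", as the patch's is_processed_logprobs_mode check does, avoids the extra copy in the case where the processed logits are what should be reported anyway.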