[BugFix] Fix logprobs with spec decode and modified logits (#30846)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-12-18 19:58:28 -08:00 committed by GitHub
parent 7b43db210c
commit 2ac85a4544
2 changed files with 30 additions and 11 deletions
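
The underlying bug: logits processors (for example a presence penalty) modify the target-logits tensor in place, so "raw" logprobs computed afterwards from the same tensor no longer reflect the unprocessed logits. A minimal sketch of that aliasing problem in plain PyTorch (not vLLM internals):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
raw_logits = logits            # alias of the same storage, not a copy
logits.add_(-1.0)              # an in-place penalty rewrites raw_logits too
print(raw_logits)              # tensor([[ 1.0000,  0.0000, -0.5000]]) -- raw values lost

logits = torch.tensor([[2.0, 1.0, 0.5]])
raw_logits = logits
processed = logits.clone()     # give the in-place modification its own copy
processed.add_(-1.0)
print(raw_logits)              # tensor([[2.0000, 1.0000, 0.5000]]) -- raw values preserved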


@@ -547,6 +547,13 @@ def test_spec_decode_logprobs(
     sampling_params = SamplingParams(
         temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
     )
+    penalty_sampling_params = SamplingParams(
+        temperature=0,
+        logprobs=top_logprobs,
+        max_tokens=10,
+        ignore_eos=False,
+        presence_penalty=-1.0,
+    )
     method, model_name, spec_model_name = model_setup
     max_model_len = 256
@@ -558,14 +565,17 @@ def test_spec_decode_logprobs(
         seed=42,
         logprobs_mode=logprobs_mode,
         gpu_memory_utilization=0.4,
+        enable_prefix_caching=False,
     )
-    ref_results = ref_llm.generate([prompt], sampling_params)
+    ref_results = ref_llm.generate(
+        [prompt, prompt], [sampling_params, penalty_sampling_params]
+    )
     # Collect logprobs outputs from reference LLM.
     ref_logprobs = []
-    for output in ref_results[0].outputs:
-        for logprobs in output.logprobs:
-            for token_id in logprobs:
-                ref_logprobs.append(logprobs[token_id])
+    for results in ref_results:
+        for output in results.outputs:
+            for logprobs in output.logprobs:
+                ref_logprobs.extend(logprobs.values())
     del ref_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
@@ -587,14 +597,17 @@ def test_spec_decode_logprobs(
         # Force prefill chunking
         enable_chunked_prefill=True,
         max_num_batched_tokens=32,
+        enable_prefix_caching=False,
     )
-    spec_results = spec_llm.generate([prompt], sampling_params)
+    spec_results = spec_llm.generate(
+        [prompt, prompt], [sampling_params, penalty_sampling_params]
+    )
     # Collect logprobs outputs from spec decode LLM.
     spec_logprobs = []
-    for output in spec_results[0].outputs:
-        for logprobs in output.logprobs:
-            for token_id in logprobs:
-                spec_logprobs.append(logprobs[token_id])
+    for results in spec_results:
+        for output in results.outputs:
+            for logprobs in output.logprobs:
+                spec_logprobs.extend(logprobs.values())
     del spec_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
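
The test now submits two copies of the prompt, the second with a presence penalty so the logits-processor path is exercised while logprobs are requested, and it flattens the per-token logprob dicts with .values(). A rough standalone sketch of that usage pattern (the model name and prompt below are placeholders, not the ones used by the test):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", max_model_len=256)  # placeholder model
prompt = "The future of AI is"                           # placeholder prompt

plain = SamplingParams(temperature=0, logprobs=5, max_tokens=10, ignore_eos=False)
penalized = SamplingParams(
    temperature=0, logprobs=5, max_tokens=10, ignore_eos=False, presence_penalty=-1.0
)

results = llm.generate([prompt, prompt], [plain, penalized])
collected = []
for result in results:
    for output in result.outputs:
        for token_logprobs in output.logprobs:        # one dict per generated token
            collected.extend(token_logprobs.values()) # Logprob entries for the top tokens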


@@ -119,8 +119,14 @@ class RejectionSampler(nn.Module):
         raw_target_logits = logits[target_logits_indices]
         # Use float32 for the target_logits.
         raw_target_logits = raw_target_logits.to(torch.float32)
+        target_logits = raw_target_logits
+        if not self.is_processed_logprobs_mode:
+            # Clone raw_target_logits before applying processors to preserve
+            # the original raw logits for logprobs computation, since
+            # apply_logits_processors modifies the tensor in-place.
+            target_logits = target_logits.clone()
         target_logits = self.apply_logits_processors(
-            raw_target_logits, sampling_metadata, metadata
+            target_logits, sampling_metadata, metadata
         )
         # [num_tokens, vocab_size]
         # NOTE(woosuk): `target_logits` can be updated in place inside the
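
The fix itself: when logprobs_mode is not a "processed" mode, logprobs must be computed from the unprocessed logits, so the tensor handed to the in-place logits processors is cloned first; in processed mode the extra copy is skipped. A minimal sketch of that guard in isolation, with assumed names (is_processed_logprobs_mode, and an in-place penalty standing in for apply_logits_processors):

import torch

def process_target_logits(
    raw_target_logits: torch.Tensor, is_processed_logprobs_mode: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    target_logits = raw_target_logits
    if not is_processed_logprobs_mode:
        # Raw logprobs are needed later, so let the in-place processors work
        # on a copy and leave raw_target_logits untouched.
        target_logits = target_logits.clone()
    # Stand-in for apply_logits_processors (which mutates its input in place).
    target_logits.sub_(1.0)
    return raw_target_logits, target_logits

raw, processed = process_target_logits(torch.tensor([[2.0, 1.0]]), False)
assert not torch.equal(raw, processed)  # raw logits survive for logprobs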