[BugFix] Fix logprobs with spec decode and modified logits (#30846)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-12-18 19:58:28 -08:00 committed by GitHub
parent 7b43db210c
commit 2ac85a4544
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 11 deletions

View File

@ -547,6 +547,13 @@ def test_spec_decode_logprobs(
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
) )
penalty_sampling_params = SamplingParams(
temperature=0,
logprobs=top_logprobs,
max_tokens=10,
ignore_eos=False,
presence_penalty=-1.0,
)
method, model_name, spec_model_name = model_setup method, model_name, spec_model_name = model_setup
max_model_len = 256 max_model_len = 256
@ -558,14 +565,17 @@ def test_spec_decode_logprobs(
seed=42, seed=42,
logprobs_mode=logprobs_mode, logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4, gpu_memory_utilization=0.4,
enable_prefix_caching=False,
)
ref_results = ref_llm.generate(
[prompt, prompt], [sampling_params, penalty_sampling_params]
) )
ref_results = ref_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from reference LLM. # Collect logprobs outputs from reference LLM.
ref_logprobs = [] ref_logprobs = []
for output in ref_results[0].outputs: for results in ref_results:
for logprobs in output.logprobs: for output in results.outputs:
for token_id in logprobs: for logprobs in output.logprobs:
ref_logprobs.append(logprobs[token_id]) ref_logprobs.extend(logprobs.values())
del ref_llm del ref_llm
torch.cuda.empty_cache() torch.cuda.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@ -587,14 +597,17 @@ def test_spec_decode_logprobs(
# Force prefill chunking # Force prefill chunking
enable_chunked_prefill=True, enable_chunked_prefill=True,
max_num_batched_tokens=32, max_num_batched_tokens=32,
enable_prefix_caching=False,
)
spec_results = spec_llm.generate(
[prompt, prompt], [sampling_params, penalty_sampling_params]
) )
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM. # Collect logprobs outputs from spec decode LLM.
spec_logprobs = [] spec_logprobs = []
for output in spec_results[0].outputs: for results in spec_results:
for logprobs in output.logprobs: for output in results.outputs:
for token_id in logprobs: for logprobs in output.logprobs:
spec_logprobs.append(logprobs[token_id]) spec_logprobs.extend(logprobs.values())
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.cuda.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()

View File

@ -119,8 +119,14 @@ class RejectionSampler(nn.Module):
raw_target_logits = logits[target_logits_indices] raw_target_logits = logits[target_logits_indices]
# Use float32 for the target_logits. # Use float32 for the target_logits.
raw_target_logits = raw_target_logits.to(torch.float32) raw_target_logits = raw_target_logits.to(torch.float32)
target_logits = raw_target_logits
if not self.is_processed_logprobs_mode:
# Clone raw_target_logits before applying processors to preserve
# the original raw logits for logprobs computation, since
# apply_logits_processors modifies the tensor in-place.
target_logits = target_logits.clone()
target_logits = self.apply_logits_processors( target_logits = self.apply_logits_processors(
raw_target_logits, sampling_metadata, metadata target_logits, sampling_metadata, metadata
) )
# [num_tokens, vocab_size] # [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the # NOTE(woosuk): `target_logits` can be updated in place inside the