[BugFix] Fix chunked prompt logprobs + preemption (#29071)

Author: Nick Hill, 2025-11-22 13:07:18 -08:00 (committed by GitHub)
parent eb5352a770
commit 7df331c66b
6 changed files with 127 additions and 31 deletions
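
In short: the per-request prompt-logprob count now lives in the model runner's request-level state instead of the per-step InputBatch, so it survives preemption and chunked prefill rather than being dropped when a request temporarily leaves the persistent batch. The sketch below shows the user-facing scenario this fixes; it mirrors the new test added in this commit, and the model name and small token/block budgets are illustrative values taken from that test, not requirements.

# Minimal sketch of the affected scenario (values mirror the new test below).
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-0.6B",
    max_model_len=512,
    enable_chunked_prefill=True,
    max_num_batched_tokens=48,   # prompts longer than this are prefilled in chunks
    num_gpu_blocks_override=32,  # tiny KV cache to force preemptions
)
params = SamplingParams(temperature=0.0, max_tokens=40, prompt_logprobs=2)

outputs = llm.generate([f"Tell me about the number {i}: " for i in range(32)], params)
for out in outputs:
    # Before this fix, a request that was preempted (or chunked) during prefill
    # could come back with missing prompt logprobs; now there is one entry per
    # prompt token (the first entry is None).
    assert out.prompt_logprobs is not None
    assert len(out.prompt_logprobs) == len(out.prompt_token_ids)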

View File

@@ -853,6 +853,7 @@ class VllmRunner:
@staticmethod
def _final_steps_generate_w_logprobs(
req_outputs: list[RequestOutput],
include_prompt_token_ids: bool = False,
) -> list[TokensTextLogprobsPromptLogprobs]:
outputs: list[TokensTextLogprobsPromptLogprobs] = []
for req_output in req_outputs:
@@ -861,9 +862,26 @@ class VllmRunner:
output_str = sample.text
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
if include_prompt_token_ids:
outputs.append(
(output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
( # type: ignore[arg-type]
output_ids,
output_str,
output_logprobs,
req_output.prompt_token_ids,
req_output.prompt_logprobs,
)
)
else:
outputs.append(
(
output_ids,
output_str,
output_logprobs,
req_output.prompt_logprobs,
)
)
return outputs
def generate_w_logprobs(
@@ -873,6 +891,7 @@ class VllmRunner:
images: PromptImageInput | None = None,
audios: PromptAudioInput | None = None,
videos: PromptVideoInput | None = None,
include_prompt_token_ids: bool = False,
**kwargs: Any,
) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
@@ -882,7 +901,7 @@ class VllmRunner:
)
toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
req_outputs
req_outputs, include_prompt_token_ids
)
# Omit prompt logprobs if not required by sampling params
return (

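For callers of the test helper above: a short sketch of consuming the new flag, assuming a VllmRunner bound to vllm_model and SamplingParams imported from vllm, as in the test further below. With include_prompt_token_ids=True the helper yields 5-tuples; without it, existing callers keep receiving the original 4-tuples.

# Sketch only; the prompt and parameter values are illustrative.
results = vllm_model.generate_w_logprobs(
    ["In one word, the capital of France is "],
    sampling_params=SamplingParams(max_tokens=8, prompt_logprobs=2),
    include_prompt_token_ids=True,
)
for out_ids, out_str, out_logprobs, prompt_ids, prompt_lps in results:
    # prompt_lps has one entry per prompt token id in prompt_ids.
    assert len(prompt_lps) == len(prompt_ids)
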
View File

@@ -605,3 +605,79 @@ def test_spec_decode_logprobs(
)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token
def test_prompt_logprobs_with_chunking_and_preemption():
"""Test that prompt logprobs are correctly returned when using
both chunked prefill and preemption.
This test ensures that the num_prompt_logprobs tracking persists
across preemptions and prefill chunks.
"""
# Create prompts that will trigger chunking and preemption
prompts = [
"The following numbers of the sequence "
+ ", ".join(str(i) for i in range(10))
+ " are:",
"In one word, the capital of France is ",
] + [f"Tell me about the number {i}: " for i in range(32)]
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=40,
min_tokens=20,
prompt_logprobs=2, # Request prompt logprobs
)
with VllmRunner(
"Qwen/Qwen3-0.6B",
max_model_len=512,
enable_chunked_prefill=True,
max_num_batched_tokens=48, # Force prefill chunking
num_gpu_blocks_override=32, # Force preemptions
disable_log_stats=False,
gpu_memory_utilization=0.25,
) as vllm_model:
metrics_before = vllm_model.llm.get_metrics()
# Generate with prompt logprobs using generate_w_logprobs, which with
# include_prompt_token_ids=True returns tuples of
# (output_ids, output_str, output_logprobs, prompt_token_ids, prompt_logprobs)
outputs = vllm_model.generate_w_logprobs(
prompts, sampling_params=sampling_params, include_prompt_token_ids=True
)
# Verify that all outputs have prompt logprobs
for i, output in enumerate(outputs):
_, _, _, prompt_token_ids, prompt_logprobs = output
assert prompt_logprobs is not None and len(prompt_logprobs) > 0, (
f"Output {i} missing prompt logprobs"
)
assert len(prompt_logprobs) == len(prompt_token_ids), (
"Unexpected number of prompt logprob positions"
)
# Each position should have the requested number of logprobs, plus at most
# one extra entry when the actual prompt token falls outside the top-k
for pos, logprobs_dict in enumerate(prompt_logprobs):
if logprobs_dict is not None: # First token may be None
assert (
sampling_params.prompt_logprobs
<= len(logprobs_dict)
<= sampling_params.prompt_logprobs + 1
), (
f"Output {i} position {pos} has {len(logprobs_dict)} "
f"logprobs, expected {sampling_params.prompt_logprobs}"
)
# Check that we actually had preemptions
metrics_after = vllm_model.llm.get_metrics()
preemptions_before = next(
(m.value for m in metrics_before if m.name == "vllm:num_preemptions"), 0
)
preemptions_after = next(
(m.value for m in metrics_after if m.name == "vllm:num_preemptions"), 0
)
preemptions = preemptions_after - preemptions_before
assert preemptions > 0, "Test did not trigger any preemptions"
print(f"Test passed with {preemptions} preemptions")
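
For context on what the assertions above check: each element of prompt_logprobs corresponds to one prompt position, the first element is None (there is no context to score the first token), and later elements map token_id to a Logprob whose rank and decoded_token fields are the same ones compared in test_spec_decode_logprobs above. A small inspection sketch, assuming (as the k/k+1 bound above implies) that the actual prompt token is always included in its position's dict:

# Sketch: inspect the top-ranked candidate and the actual prompt token at
# each scored position of the first request's prompt.
_, _, _, prompt_token_ids, prompt_logprobs = outputs[0]
assert prompt_logprobs[0] is None
for token_id, pos_logprobs in zip(prompt_token_ids[1:], prompt_logprobs[1:]):
    top = max(pos_logprobs.values(), key=lambda lp: lp.logprob)
    print(top.rank, repr(top.decoded_token), pos_logprobs[token_id].logprob)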

View File

@@ -219,9 +219,6 @@ class InputBatch:
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -385,12 +382,6 @@ class InputBatch:
if sampling_params.logprobs == -1
else sampling_params.logprobs
)
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
self.vocab_size
if sampling_params.prompt_logprobs == -1
else sampling_params.prompt_logprobs
)
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
@@ -488,7 +479,6 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
self.has_allowed_token_ids.discard(req_id)
@@ -972,10 +962,6 @@ class InputBatch:
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0
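
Context for the removals above: InputBatch bookkeeping is discarded whenever a request leaves the persistent batch, and preemption does exactly that, so a partially prefilled request lost its prompt-logprob count before the remaining chunks were processed. A minimal sketch of the ownership change, with simplified stand-in classes (Batch and Runner here are illustrative, not the real vLLM types):

class Batch:
    """Per-step state: cleared whenever a request is evicted, including on preemption."""

    def __init__(self) -> None:
        self.per_step_state: dict[str, int] = {}

    def remove_request(self, req_id: str) -> None:
        # Preemption and completion both land here, so anything stored per
        # batch is lost as soon as the request is preempted.
        self.per_step_state.pop(req_id, None)


class Runner:
    """Request-level state: keyed by request id, cleared only when the request finishes."""

    def __init__(self) -> None:
        self.batch = Batch()
        self.num_prompt_logprobs: dict[str, int] = {}

    def add_request(self, req_id: str, prompt_logprobs: int | None, vocab_size: int) -> None:
        if prompt_logprobs is not None:
            # -1 means "all vocab entries", mirroring the runner change below.
            self.num_prompt_logprobs[req_id] = (
                vocab_size if prompt_logprobs == -1 else prompt_logprobs
            )

    def finish_request(self, req_id: str) -> None:
        # Only a finished request drops its count, so preempted-then-resumed
        # requests keep accumulating prompt logprobs across chunks.
        self.num_prompt_logprobs.pop(req_id, None)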

View File

@@ -393,6 +393,9 @@ class GPUModelRunner(
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
self.comm_stream = torch.cuda.Stream()
# Input Batch
@@ -687,6 +690,7 @@ class GPUModelRunner(
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -755,6 +759,13 @@ class GPUModelRunner(
)
self.requests[req_id] = req_state
if sampling_params and sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
self.input_batch.vocab_size
if sampling_params.prompt_logprobs == -1
else sampling_params.prompt_logprobs
)
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
self._init_mrope_positions(req_state)
@@ -2671,7 +2682,7 @@ class GPUModelRunner(
scheduler_output, self.vllm_config
)
if self.cache_config.kv_sharing_fast_prefill:
assert not self.input_batch.num_prompt_logprobs, (
assert not self.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect "
"logprobs for prompt tokens, please disable "
"it when the requests need prompt logprobs"
@@ -3436,7 +3447,7 @@ class GPUModelRunner(
hidden_states: torch.Tensor,
num_scheduled_tokens: dict[str, int],
) -> dict[str, LogprobsTensors | None]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
num_prompt_logprobs_dict = self.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}
@@ -3447,7 +3458,10 @@ class GPUModelRunner(
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
num_tokens = num_scheduled_tokens[req_id]
num_tokens = num_scheduled_tokens.get(req_id)
if num_tokens is None:
# This can happen if the request was preempted during the prefill stage.
continue
# Get metadata for this request.
request = self.requests[req_id]

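The second half of the runner change above: a request id can be present in num_prompt_logprobs while absent from this step's num_scheduled_tokens, because it was preempted mid-prefill and not scheduled this step. A small sketch of that guard with hypothetical names (gather_scheduled is not a real vLLM function):

def gather_scheduled(
    num_prompt_logprobs: dict[str, int],
    num_scheduled_tokens: dict[str, int],
) -> dict[str, int]:
    """Return only the prompt-logprob requests that were scheduled this step."""
    scheduled: dict[str, int] = {}
    for req_id in num_prompt_logprobs:
        num_tokens = num_scheduled_tokens.get(req_id)
        if num_tokens is None:
            # Preempted during prefill: keep the dict entry untouched so the
            # remaining prompt chunks are scored once the request resumes.
            continue
        scheduled[req_id] = num_tokens
    return scheduled
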
View File

@@ -149,9 +149,6 @@ class InputBatch:
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -256,8 +253,6 @@ class InputBatch:
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
@@ -317,7 +312,6 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
@@ -584,10 +578,6 @@ class InputBatch:
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0

View File

@@ -247,6 +247,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# Initialize input batch early to avoid AttributeError in _update_states
self.input_batch = InputBatch(
@@ -420,6 +423,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -477,6 +481,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
lora_request=new_req_data.lora_request,
)
if sampling_params and sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
self.input_batch.vocab_size
if sampling_params.prompt_logprobs == -1
else sampling_params.prompt_logprobs
)
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.