[BugFix] Fix chunked prompt logprobs + preemption (#29071)
commit 7df331c66b
parent eb5352a770
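The commit message carries no body, so for context: the changes below move per-request prompt-logprob tracking out of the per-step input batch (whose entries are dropped whenever a request leaves the persistent batch, including on preemption) and into the model runner, which only forgets a request when it finishes. A minimal sketch of the user-facing scenario through the public offline API follows; the model name and engine limits are illustrative values borrowed from the new test, not part of this commit:

    from vllm import LLM, SamplingParams

    # Tight KV-cache and per-step token budgets make chunked prefill and
    # preemption likely (illustrative values, mirroring the new test).
    llm = LLM(
        model="Qwen/Qwen3-0.6B",
        max_model_len=512,
        enable_chunked_prefill=True,
        max_num_batched_tokens=48,
        num_gpu_blocks_override=32,
        gpu_memory_utilization=0.25,
    )
    params = SamplingParams(max_tokens=40, prompt_logprobs=2)
    outputs = llm.generate(["In one word, the capital of France is "], params)
    # With the fix, prompt-logprob tracking persists across preemptions and
    # prefill chunks, so each prompt position gets an entry (the first may be None).
    print(outputs[0].prompt_logprobs)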
@@ -853,6 +853,7 @@ class VllmRunner:
     @staticmethod
     def _final_steps_generate_w_logprobs(
         req_outputs: list[RequestOutput],
+        include_prompt_token_ids: bool = False,
     ) -> list[TokensTextLogprobsPromptLogprobs]:
         outputs: list[TokensTextLogprobsPromptLogprobs] = []
         for req_output in req_outputs:
@@ -861,9 +862,26 @@ class VllmRunner:
             output_str = sample.text
             output_ids = list(sample.token_ids)
             output_logprobs = sample.logprobs
-            outputs.append(
-                (output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
-            )
+            if include_prompt_token_ids:
+                outputs.append(
+                    (  # type: ignore[arg-type]
+                        output_ids,
+                        output_str,
+                        output_logprobs,
+                        req_output.prompt_token_ids,
+                        req_output.prompt_logprobs,
+                    )
+                )
+            else:
+                outputs.append(
+                    (
+                        output_ids,
+                        output_str,
+                        output_logprobs,
+                        req_output.prompt_logprobs,
+                    )
+                )

         return outputs

     def generate_w_logprobs(
@@ -873,6 +891,7 @@ class VllmRunner:
         images: PromptImageInput | None = None,
         audios: PromptAudioInput | None = None,
         videos: PromptVideoInput | None = None,
+        include_prompt_token_ids: bool = False,
         **kwargs: Any,
     ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
         inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
@@ -882,7 +901,7 @@ class VllmRunner:
         )

         toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
-            req_outputs
+            req_outputs, include_prompt_token_ids
         )
         # Omit prompt logprobs if not required by sampling params
         return (
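The new include_prompt_token_ids flag changes the shape of what generate_w_logprobs returns: a 5-tuple that also carries the prompt token ids, versus the existing 4-tuple. A rough usage sketch, assuming VllmRunner can be imported from tests/conftest.py as the new test does, with an illustrative model and prompt:

    from vllm import SamplingParams
    from tests.conftest import VllmRunner  # assumed import path

    params = SamplingParams(max_tokens=8, prompt_logprobs=2)
    with VllmRunner("Qwen/Qwen3-0.6B", max_model_len=512) as vllm_model:
        # include_prompt_token_ids=True: 5-tuple per request, so callers can
        # compare len(prompt_logprobs) against len(prompt_token_ids).
        ids, text, logprobs, prompt_ids, prompt_lps = vllm_model.generate_w_logprobs(
            ["Hello"], sampling_params=params, include_prompt_token_ids=True
        )[0]
        # Default: the original 4-tuple without prompt token ids.
        ids, text, logprobs, prompt_lps = vllm_model.generate_w_logprobs(
            ["Hello"], sampling_params=params
        )[0]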
@@ -605,3 +605,79 @@ def test_spec_decode_logprobs(
         )
         assert ref_logprob.rank == spec_logprob.rank
         assert ref_logprob.decoded_token == spec_logprob.decoded_token
+
+
+def test_prompt_logprobs_with_chunking_and_preemption():
+    """Test that prompt logprobs are correctly returned when using
+    both chunked prefill and preemption.
+
+    This test ensures that the num_prompt_logprobs tracking persists
+    across preemptions and prefill chunks.
+    """
+
+    # Create prompts that will trigger chunking and preemption
+    prompts = [
+        "The following numbers of the sequence "
+        + ", ".join(str(i) for i in range(10))
+        + " are:",
+        "In one word, the capital of France is ",
+    ] + [f"Tell me about the number {i}: " for i in range(32)]
+
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=40,
+        min_tokens=20,
+        prompt_logprobs=2,  # Request prompt logprobs
+    )
+
+    with VllmRunner(
+        "Qwen/Qwen3-0.6B",
+        max_model_len=512,
+        enable_chunked_prefill=True,
+        max_num_batched_tokens=48,  # Force prefill chunking
+        num_gpu_blocks_override=32,  # Force preemptions
+        disable_log_stats=False,
+        gpu_memory_utilization=0.25,
+    ) as vllm_model:
+        metrics_before = vllm_model.llm.get_metrics()
+
+        # Generate with prompt logprobs using generate_w_logprobs which
+        # returns (output_ids, output_str, output_logprobs, prompt_logprobs)
+        outputs = vllm_model.generate_w_logprobs(
+            prompts, sampling_params=sampling_params, include_prompt_token_ids=True
+        )
+
+        # Verify that all outputs have prompt logprobs
+        for i, output in enumerate(outputs):
+            _, _, _, prompt_token_ids, prompt_logprobs = output
+            assert prompt_logprobs is not None and len(prompt_logprobs) > 0, (
+                f"Output {i} missing prompt logprobs"
+            )
+            assert len(prompt_logprobs) == len(prompt_token_ids), (
+                "Unexpected number of prompt logprob positions"
+            )
+
+            # Each position should have the requested number of logprobs
+            for pos, logprobs_dict in enumerate(prompt_logprobs):
+                if logprobs_dict is not None:  # First token may be None
+                    assert (
+                        sampling_params.prompt_logprobs
+                        <= len(logprobs_dict)
+                        <= sampling_params.prompt_logprobs + 1
+                    ), (
+                        f"Output {i} position {pos} has {len(logprobs_dict)} "
+                        f"logprobs, expected {sampling_params.prompt_logprobs}"
+                    )
+
+        # Check that we actually had preemptions
+        metrics_after = vllm_model.llm.get_metrics()
+        preemptions_before = next(
+            (m.value for m in metrics_before if m.name == "vllm:num_preemptions"), 0
+        )
+        preemptions_after = next(
+            (m.value for m in metrics_after if m.name == "vllm:num_preemptions"), 0
+        )
+        preemptions = preemptions_after - preemptions_before
+        assert preemptions > 0, "Test did not trigger any preemptions"
+
+        print(f"Test passed with {preemptions} preemptions")
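The engine limits in the test are picked so both code paths fire: max_num_batched_tokens=48 caps how many tokens each scheduler step may prefill, forcing the prompts to be prefilled in chunks, while num_gpu_blocks_override=32 shrinks the KV cache so the 34 concurrent requests cannot all hold blocks at once and some get preempted. A back-of-the-envelope check of the preemption pressure, assuming vLLM's default 16-token KV block (the block size is an assumption, not stated in the commit):

    # Rough capacity check: why 32 KV blocks forces preemption in this test.
    block_size = 16                            # assumed default block size
    kv_capacity_tokens = 32 * block_size       # 512 token slots shared by all requests
    num_requests = 2 + 32                      # prompts constructed in the test
    min_decode_tokens = 20                     # SamplingParams(min_tokens=20)
    demand = num_requests * min_decode_tokens  # ignores prompt tokens entirely
    print(kv_capacity_tokens, demand)          # 512 < 680, so some requests must be preempted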
@@ -219,9 +219,6 @@ class InputBatch:
         self.generators: dict[int, torch.Generator] = {}

         self.num_logprobs: dict[str, int] = {}
-        # NOTE(rob): num_prompt_logprobs only includes reqs
-        # that are currently in the prefill phase.
-        self.num_prompt_logprobs: dict[str, int] = {}

         # To accumulate prompt logprobs tensor chunks across prefill steps.
         self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -385,12 +382,6 @@ class InputBatch:
                 if sampling_params.logprobs == -1
                 else sampling_params.logprobs
             )
-        if sampling_params.prompt_logprobs is not None:
-            self.num_prompt_logprobs[req_id] = (
-                self.vocab_size
-                if sampling_params.prompt_logprobs == -1
-                else sampling_params.prompt_logprobs
-            )

         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
@@ -488,7 +479,6 @@ class InputBatch:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
-        self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

         self.has_allowed_token_ids.discard(req_id)
@@ -972,10 +962,6 @@ class InputBatch:
     def max_num_logprobs(self) -> int | None:
         return max(self.num_logprobs.values()) if self.num_logprobs else None

-    @property
-    def no_prompt_logprob(self) -> bool:
-        return not self.num_prompt_logprobs
-
     @property
     def no_allowed_token_ids(self) -> bool:
         return len(self.has_allowed_token_ids) == 0
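These removals are the substantive part of the fix on the batch side: InputBatch drops a request's entries in remove_request, which also runs when a request is preempted out of the persistent batch, so a prompt-logprob count kept here did not persist across a mid-prefill preemption. Ownership moves to the model runner in the next hunks. A toy sketch of the ownership difference (plain Python, not vLLM code; all names are made up):

    # Toy illustration: state keyed on the per-step batch goes away on
    # preemption, state owned by the runner lives until the request finishes.
    batch_num_prompt_logprobs: dict[str, int] = {}   # old home, cleared by remove_request
    runner_num_prompt_logprobs: dict[str, int] = {}  # new home, cleared when finished

    def add_request(req_id: str, n: int) -> None:
        batch_num_prompt_logprobs[req_id] = n
        runner_num_prompt_logprobs[req_id] = n

    def preempt(req_id: str) -> None:
        # Preemption removes the request from the persistent batch only.
        batch_num_prompt_logprobs.pop(req_id, None)

    def finish(req_id: str) -> None:
        runner_num_prompt_logprobs.pop(req_id, None)

    add_request("req-0", 2)
    preempt("req-0")
    assert "req-0" not in batch_num_prompt_logprobs   # count lost with the old layout
    assert runner_num_prompt_logprobs["req-0"] == 2   # still tracked with the new layout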
@@ -393,6 +393,9 @@ class GPUModelRunner(

         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
+        # NOTE(rob): num_prompt_logprobs only includes reqs
+        # that are currently in the prefill phase.
+        self.num_prompt_logprobs: dict[str, int] = {}
         self.comm_stream = torch.cuda.Stream()

         # Input Batch
@@ -687,6 +690,7 @@ class GPUModelRunner(
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.num_prompt_logprobs.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -755,6 +759,13 @@ class GPUModelRunner(
             )
             self.requests[req_id] = req_state

+            if sampling_params and sampling_params.prompt_logprobs is not None:
+                self.num_prompt_logprobs[req_id] = (
+                    self.input_batch.vocab_size
+                    if sampling_params.prompt_logprobs == -1
+                    else sampling_params.prompt_logprobs
+                )
+
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
                 self._init_mrope_positions(req_state)
@@ -2671,7 +2682,7 @@ class GPUModelRunner(
             scheduler_output, self.vllm_config
         )
         if self.cache_config.kv_sharing_fast_prefill:
-            assert not self.input_batch.num_prompt_logprobs, (
+            assert not self.num_prompt_logprobs, (
                 "--kv-sharing-fast-prefill produces incorrect "
                 "logprobs for prompt tokens, tokens, please disable "
                 "it when the requests need prompt logprobs"
@@ -3436,7 +3447,7 @@ class GPUModelRunner(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: dict[str, int],
     ) -> dict[str, LogprobsTensors | None]:
-        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        num_prompt_logprobs_dict = self.num_prompt_logprobs
         if not num_prompt_logprobs_dict:
             return {}

@@ -3447,7 +3458,10 @@ class GPUModelRunner(
         # maintainable loop over optimal performance.
         completed_prefill_reqs = []
         for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
-            num_tokens = num_scheduled_tokens[req_id]
+            num_tokens = num_scheduled_tokens.get(req_id)
+            if num_tokens is None:
+                # This can happen if the request was preempted in prefill stage.
+                continue

             # Get metadata for this request.
             request = self.requests[req_id]
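Because the runner-level dict now outlives a request's membership in the persistent batch, _get_prompt_logprobs_dict can see a tracked request that was not scheduled in the current step (for example, one preempted during prefill), hence the switch from direct indexing to .get() with an explicit skip. A minimal sketch of that guard, using toy data rather than vLLM objects:

    # Tracked across steps by the runner:
    num_prompt_logprobs = {"req-0": 2, "req-1": 2}
    # Scheduled this step; "req-1" was preempted and is absent:
    num_scheduled_tokens = {"req-0": 48}

    for req_id, n_logprobs in num_prompt_logprobs.items():
        num_tokens = num_scheduled_tokens.get(req_id)
        if num_tokens is None:
            # Not scheduled this step; its prompt logprobs resume later.
            continue
        print(req_id, num_tokens, n_logprobs)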
@@ -149,9 +149,6 @@ class InputBatch:
         self.generators: dict[int, torch.Generator] = {}

         self.num_logprobs: dict[str, int] = {}
-        # NOTE(rob): num_prompt_logprobs only includes reqs
-        # that are currently in the prefill phase.
-        self.num_prompt_logprobs: dict[str, int] = {}

         # To accumulate prompt logprobs tensor chunks across prefill steps.
         self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -256,8 +253,6 @@ class InputBatch:

         if sampling_params.logprobs is not None:
             self.num_logprobs[req_id] = sampling_params.logprobs
-        if sampling_params.prompt_logprobs is not None:
-            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias

@@ -317,7 +312,6 @@ class InputBatch:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
-        self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

         # LoRA
@@ -584,10 +578,6 @@ class InputBatch:
     def max_num_logprobs(self) -> int | None:
         return max(self.num_logprobs.values()) if self.num_logprobs else None

-    @property
-    def no_prompt_logprob(self) -> bool:
-        return not self.num_prompt_logprobs
-
     @property
     def no_allowed_token_ids(self) -> bool:
         return len(self.has_allowed_token_ids) == 0
@@ -247,6 +247,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
+        # NOTE(rob): num_prompt_logprobs only includes reqs
+        # that are currently in the prefill phase.
+        self.num_prompt_logprobs: dict[str, int] = {}

         # Initialize input batch early to avoid AttributeError in _update_states
         self.input_batch = InputBatch(
@@ -420,6 +423,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.num_prompt_logprobs.pop(req_id, None)

         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -477,6 +481,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 lora_request=new_req_data.lora_request,
             )

+            if sampling_params and sampling_params.prompt_logprobs is not None:
+                self.num_prompt_logprobs[req_id] = (
+                    self.input_batch.vocab_size
+                    if sampling_params.prompt_logprobs == -1
+                    else sampling_params.prompt_logprobs
+                )
+
             req_ids_to_add.append(req_id)

         # Update the states of the running/resumed requests.