diff --git a/tests/conftest.py b/tests/conftest.py
index 5afdb225b892..163593eb3f14 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -853,6 +853,7 @@ class VllmRunner:
     @staticmethod
     def _final_steps_generate_w_logprobs(
         req_outputs: list[RequestOutput],
+        include_prompt_token_ids: bool = False,
     ) -> list[TokensTextLogprobsPromptLogprobs]:
         outputs: list[TokensTextLogprobsPromptLogprobs] = []
         for req_output in req_outputs:
@@ -861,9 +862,26 @@ class VllmRunner:
             output_str = sample.text
             output_ids = list(sample.token_ids)
             output_logprobs = sample.logprobs
-            outputs.append(
-                (output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
-            )
+            if include_prompt_token_ids:
+                outputs.append(
+                    (  # type: ignore[arg-type]
+                        output_ids,
+                        output_str,
+                        output_logprobs,
+                        req_output.prompt_token_ids,
+                        req_output.prompt_logprobs,
+                    )
+                )
+            else:
+                outputs.append(
+                    (
+                        output_ids,
+                        output_str,
+                        output_logprobs,
+                        req_output.prompt_logprobs,
+                    )
+                )
+
         return outputs
 
     def generate_w_logprobs(
@@ -873,6 +891,7 @@
         images: PromptImageInput | None = None,
         audios: PromptAudioInput | None = None,
         videos: PromptVideoInput | None = None,
+        include_prompt_token_ids: bool = False,
         **kwargs: Any,
     ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
         inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
@@ -882,7 +901,7 @@
         )
 
         toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
-            req_outputs
+            req_outputs, include_prompt_token_ids
         )
         # Omit prompt logprobs if not required by sampling params
         return (
diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py
index c0b0e1ea226e..c89c33be80c1 100644
--- a/tests/v1/sample/test_logprobs.py
+++ b/tests/v1/sample/test_logprobs.py
@@ -605,3 +605,79 @@ def test_spec_decode_logprobs(
         )
         assert ref_logprob.rank == spec_logprob.rank
         assert ref_logprob.decoded_token == spec_logprob.decoded_token
+
+
+def test_prompt_logprobs_with_chunking_and_preemption():
+    """Test that prompt logprobs are correctly returned when using
+    both chunked prefill and preemption.
+
+    This test ensures that the num_prompt_logprobs tracking persists
+    across preemptions and prefill chunks.
+ """ + + # Create prompts that will trigger chunking and preemption + prompts = [ + "The following numbers of the sequence " + + ", ".join(str(i) for i in range(10)) + + " are:", + "In one word, the capital of France is ", + ] + [f"Tell me about the number {i}: " for i in range(32)] + + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=40, + min_tokens=20, + prompt_logprobs=2, # Request prompt logprobs + ) + + with VllmRunner( + "Qwen/Qwen3-0.6B", + max_model_len=512, + enable_chunked_prefill=True, + max_num_batched_tokens=48, # Force prefill chunking + num_gpu_blocks_override=32, # Force preemptions + disable_log_stats=False, + gpu_memory_utilization=0.25, + ) as vllm_model: + metrics_before = vllm_model.llm.get_metrics() + + # Generate with prompt logprobs using generate_w_logprobs which + # returns (output_ids, output_str, output_logprobs, prompt_logprobs) + outputs = vllm_model.generate_w_logprobs( + prompts, sampling_params=sampling_params, include_prompt_token_ids=True + ) + + # Verify that all outputs have prompt logprobs + for i, output in enumerate(outputs): + _, _, _, prompt_token_ids, prompt_logprobs = output + assert prompt_logprobs is not None and len(prompt_logprobs) > 0, ( + f"Output {i} missing prompt logprobs" + ) + assert len(prompt_logprobs) == len(prompt_token_ids), ( + "Unexpected number of prompt logprob positions" + ) + + # Each position should have the requested number of logprobs + for pos, logprobs_dict in enumerate(prompt_logprobs): + if logprobs_dict is not None: # First token may be None + assert ( + sampling_params.prompt_logprobs + <= len(logprobs_dict) + <= sampling_params.prompt_logprobs + 1 + ), ( + f"Output {i} position {pos} has {len(logprobs_dict)} " + f"logprobs, expected {sampling_params.prompt_logprobs}" + ) + + # Check that we actually had preemptions + metrics_after = vllm_model.llm.get_metrics() + preemptions_before = next( + (m.value for m in metrics_before if m.name == "vllm:num_preemptions"), 0 + ) + preemptions_after = next( + (m.value for m in metrics_after if m.name == "vllm:num_preemptions"), 0 + ) + preemptions = preemptions_after - preemptions_before + assert preemptions > 0, "Test did not trigger any preemptions" + + print(f"Test passed with {preemptions} preemptions") diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 7b4bc1d2a224..d6fef450c028 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -219,9 +219,6 @@ class InputBatch: self.generators: dict[int, torch.Generator] = {} self.num_logprobs: dict[str, int] = {} - # NOTE(rob): num_prompt_logprobs only includes reqs - # that are currently in the prefill phase. - self.num_prompt_logprobs: dict[str, int] = {} # To accumulate prompt logprobs tensor chunks across prefill steps. 
         self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
 
@@ -385,12 +382,6 @@ class InputBatch:
                 if sampling_params.logprobs == -1
                 else sampling_params.logprobs
             )
-        if sampling_params.prompt_logprobs is not None:
-            self.num_prompt_logprobs[req_id] = (
-                self.vocab_size
-                if sampling_params.prompt_logprobs == -1
-                else sampling_params.prompt_logprobs
-            )
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
 
@@ -488,7 +479,6 @@ class InputBatch:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
-        self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
         self.has_allowed_token_ids.discard(req_id)
 
@@ -972,10 +962,6 @@ class InputBatch:
     def max_num_logprobs(self) -> int | None:
         return max(self.num_logprobs.values()) if self.num_logprobs else None
 
-    @property
-    def no_prompt_logprob(self) -> bool:
-        return not self.num_prompt_logprobs
-
     @property
     def no_allowed_token_ids(self) -> bool:
         return len(self.has_allowed_token_ids) == 0
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 979f97758703..49285a7b8e0a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -393,6 +393,9 @@ class GPUModelRunner(
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
+        # NOTE(rob): num_prompt_logprobs only includes reqs
+        # that are currently in the prefill phase.
+        self.num_prompt_logprobs: dict[str, int] = {}
         self.comm_stream = torch.cuda.Stream()
 
         # Input Batch
@@ -687,6 +690,7 @@ class GPUModelRunner(
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.num_prompt_logprobs.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -755,6 +759,13 @@ class GPUModelRunner(
             )
             self.requests[req_id] = req_state
 
+            if sampling_params and sampling_params.prompt_logprobs is not None:
+                self.num_prompt_logprobs[req_id] = (
+                    self.input_batch.vocab_size
+                    if sampling_params.prompt_logprobs == -1
+                    else sampling_params.prompt_logprobs
+                )
+
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
                 self._init_mrope_positions(req_state)
@@ -2671,7 +2682,7 @@ class GPUModelRunner(
                 scheduler_output, self.vllm_config
             )
             if self.cache_config.kv_sharing_fast_prefill:
-                assert not self.input_batch.num_prompt_logprobs, (
+                assert not self.num_prompt_logprobs, (
                     "--kv-sharing-fast-prefill produces incorrect "
                     "logprobs for prompt tokens, tokens, please disable "
                     "it when the requests need prompt logprobs"
                 )
@@ -3436,7 +3447,7 @@ class GPUModelRunner(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: dict[str, int],
     ) -> dict[str, LogprobsTensors | None]:
-        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        num_prompt_logprobs_dict = self.num_prompt_logprobs
         if not num_prompt_logprobs_dict:
             return {}
 
@@ -3447,7 +3458,10 @@ class GPUModelRunner(
         # maintainable loop over optimal performance.
         completed_prefill_reqs = []
         for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
-            num_tokens = num_scheduled_tokens[req_id]
+            num_tokens = num_scheduled_tokens.get(req_id)
+            if num_tokens is None:
+                # This can happen if the request was preempted in the prefill stage.
+                continue
 
             # Get metadata for this request.
             request = self.requests[req_id]
diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py
index 6bf4f9193184..2ed65ca9d31c 100644
--- a/vllm/v1/worker/tpu_input_batch.py
+++ b/vllm/v1/worker/tpu_input_batch.py
@@ -149,9 +149,6 @@ class InputBatch:
         self.generators: dict[int, torch.Generator] = {}
 
         self.num_logprobs: dict[str, int] = {}
-        # NOTE(rob): num_prompt_logprobs only includes reqs
-        # that are currently in the prefill phase.
-        self.num_prompt_logprobs: dict[str, int] = {}
         # To accumulate prompt logprobs tensor chunks across prefill steps.
         self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
 
@@ -256,8 +253,6 @@ class InputBatch:
 
         if sampling_params.logprobs is not None:
             self.num_logprobs[req_id] = sampling_params.logprobs
-        if sampling_params.prompt_logprobs is not None:
-            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias
 
@@ -317,7 +312,6 @@ class InputBatch:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
-        self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
 
         # LoRA
@@ -584,10 +578,6 @@ class InputBatch:
     def max_num_logprobs(self) -> int | None:
         return max(self.num_logprobs.values()) if self.num_logprobs else None
 
-    @property
-    def no_prompt_logprob(self) -> bool:
-        return not self.num_prompt_logprobs
-
     @property
     def no_allowed_token_ids(self) -> bool:
         return len(self.has_allowed_token_ids) == 0
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 5f6012ec614c..72d4474b8962 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -247,6 +247,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
+        # NOTE(rob): num_prompt_logprobs only includes reqs
+        # that are currently in the prefill phase.
+        self.num_prompt_logprobs: dict[str, int] = {}
 
         # Initialize input batch early to avoid AttributeError in _update_states
         self.input_batch = InputBatch(
@@ -420,6 +423,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.num_prompt_logprobs.pop(req_id, None)
 
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -477,6 +481,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 lora_request=new_req_data.lora_request,
             )
 
+            if sampling_params and sampling_params.prompt_logprobs is not None:
+                self.num_prompt_logprobs[req_id] = (
+                    self.input_batch.vocab_size
+                    if sampling_params.prompt_logprobs == -1
+                    else sampling_params.prompt_logprobs
+                )
+
             req_ids_to_add.append(req_id)
 
         # Update the states of the running/resumed requests.