diff --git a/tests/v1/e2e/test_async_sched_and_preempt.py b/tests/v1/e2e/test_async_sched_and_preempt.py
index 54fa9aca381bb..0f7ccb35a7576 100644
--- a/tests/v1/e2e/test_async_sched_and_preempt.py
+++ b/tests/v1/e2e/test_async_sched_and_preempt.py
@@ -28,9 +28,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
     sampling_param_tests: list[dict[str, Any]] = [
         dict(),
         # dict(min_tokens=20),
-        # TODO enable these with https://github.com/vllm-project/vllm/pull/26467.
-        # dict(repetition_penalty=0.1),
-        # dict(bad_words=[]),
+        dict(presence_penalty=-1.0),
+        dict(bad_words=["the", " the"]),
     ]
 
     default_params = dict(
@@ -42,9 +41,9 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
         # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
 
-        outputs = []
+        outputs: list[tuple[str, list]] = []
         for test_preemption in [False, True]:
-            for executor in ["uni", "mp"]:
+            for executor in ["mp", "uni"]:
                 for async_scheduling in [False, True]:
                     cache_arg: dict[str, Any] = (
                         dict(num_gpu_blocks_override=32)
@@ -78,6 +77,21 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
                             ),
                         )
                     )
+
+                    if not outputs:
+                        # First check that the different parameter configs
+                        # actually result in different output.
+                        for other_test, params in zip(
+                            results[1:], sampling_param_tests[1:]
+                        ):
+                            with pytest.raises(AssertionError):
+                                check_outputs_equal(
+                                    outputs_0_lst=results[0],
+                                    outputs_1_lst=other_test,
+                                    name_0=f"baseline params={params}",
+                                    name_1=f"other params={params}",
+                                )
+
                     outputs.append((test_config, results))
 
     baseline_config, baseline_tests = outputs[0]
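The two configs enabled above, presence_penalty=-1.0 and bad_words, both read the tokens generated so far, which is exactly what the async-scheduling changes below have to keep correct. A tiny illustrative sketch of a presence penalty driven by output_token_ids (the helper apply_presence_penalty is made up for this note, not a vLLM API; the real sampler applies penalties batched on the GPU):

import torch


def apply_presence_penalty(
    logits: torch.Tensor, output_token_ids: list[int], presence_penalty: float
) -> torch.Tensor:
    """Subtract the penalty from every token id that has already been generated."""
    if not output_token_ids:
        return logits
    seen = torch.tensor(sorted(set(output_token_ids)), dtype=torch.long)
    penalized = logits.clone()
    penalized[seen] -= presence_penalty
    return penalized


if __name__ == "__main__":
    logits = torch.zeros(10)
    # With presence_penalty=-1.0 (as in the test), already-seen tokens get a boost,
    # so the output visibly diverges from the default-params baseline.
    print(apply_presence_penalty(logits, [2, 5, 5], presence_penalty=-1.0))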
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 5bc7a488bf83a..0f1504724d7c6 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -737,7 +737,9 @@ class Scheduler(SchedulerInterface):
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
             num_computed_tokens.append(req.num_computed_tokens)
-            num_output_tokens.append(req.num_output_tokens)
+            num_output_tokens.append(
+                req.num_output_tokens + req.num_output_placeholders
+            )
 
         return CachedRequestData(
             req_ids=req_ids,
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index bc1186e5feb7b..0ced400bcb663 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -79,6 +79,7 @@ class InputBatch:
         block_sizes: list[int],  # The block_size of each kv cache group
         kernel_block_sizes: list[int],
         logitsprocs: Optional[LogitsProcessors] = None,
+        logitsprocs_need_output_token_ids: bool = False,
         is_spec_decode: bool = False,
         is_pooling_model: bool = False,
         num_speculative_tokens: int = 0,
@@ -240,6 +241,7 @@ class InputBatch:
         # Store provided logitsprocs. If none are provided, initialize empty
         # data structure
         self.logitsprocs = logitsprocs or LogitsProcessors()
+        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
 
         # Store last speculative tokens for sampler.
         self.spec_token_ids: list[Optional[list[int]]] = []
@@ -252,6 +254,11 @@ class InputBatch:
         # Cached reference to the GPU tensor of previously sampled tokens
         self.prev_sampled_token_ids: Optional[torch.Tensor] = None
         self.prev_req_id_to_index: Optional[dict[str, int]] = None
+        # These are used to update output_token_ids with the real sampled
+        # ids from the prior step, if required by the current sampling params
+        # (e.g. penalties).
+        self.sampled_token_ids_cpu: Optional[torch.Tensor] = None
+        self.async_copy_ready_event: Optional[torch.cuda.Event] = None
 
     @property
     def req_ids(self) -> list[str]:
@@ -776,6 +783,19 @@ class InputBatch:
             self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
         )
 
+        # Only set output_token_ids if required by the current requests'
+        # sampling parameters.
+        needs_output_token_ids = (
+            not self.no_penalties
+            or bool(self.bad_words_token_ids)
+            or self.logitsprocs_need_output_token_ids
+        )
+        output_token_ids = (
+            cast(list[list[int]], self.req_output_token_ids)
+            if needs_output_token_ids
+            else []
+        )
+
         allowed_token_ids_mask: Optional[torch.Tensor] = None
         if not self.no_allowed_token_ids:
             assert self.allowed_token_ids_mask is not None
@@ -798,7 +818,7 @@ class InputBatch:
             frequency_penalties=self.frequency_penalties[:num_reqs],
             presence_penalties=self.presence_penalties[:num_reqs],
             repetition_penalties=self.repetition_penalties[:num_reqs],
-            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
+            output_token_ids=output_token_ids,
            spec_token_ids=cast(list[list[int]], self.spec_token_ids),
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
@@ -859,6 +879,52 @@ class InputBatch:
 
         return prompt_lora_mapping, token_lora_mapping, active_lora_requests
 
+    def set_async_sampled_token_ids(
+        self,
+        sampled_token_ids_cpu: torch.Tensor,
+        async_copy_ready_event: torch.cuda.Event,
+    ) -> None:
+        """
+        In the async scheduling case, store a ref to the sampled_token_ids_cpu
+        tensor and the corresponding copy-ready event. Used to repair
+        output_token_ids prior to sampling, if needed by logits processors.
+        """
+        if self.sampling_metadata.output_token_ids:
+            self.sampled_token_ids_cpu = sampled_token_ids_cpu
+            self.async_copy_ready_event = async_copy_ready_event
+        else:
+            self.sampled_token_ids_cpu = None
+            self.async_copy_ready_event = None
+
+    def update_async_output_token_ids(self) -> None:
+        """
+        In the async scheduling case, update output_token_ids in sampling metadata
+        from the prior step's sampled token ids once they've finished copying to
+        CPU. This is called right before they are needed by the logits processors.
+        """
+        output_token_ids = self.sampling_metadata.output_token_ids
+        if self.sampled_token_ids_cpu is None or not output_token_ids:
+            # Output token ids not needed or not async scheduling.
+            return
+
+        assert self.prev_req_id_to_index is not None
+        sampled_token_ids = None
+        for index, req_id in enumerate(self.req_ids):
+            prev_index = self.prev_req_id_to_index.get(req_id)
+            if prev_index is None:
+                continue
+            req_output_token_ids = output_token_ids[index]
+            if not req_output_token_ids or req_output_token_ids[-1] != -1:
+                # Final output id is not a placeholder; some tokens must have
+                # been discarded after a kv-load failure.
+                continue
+            if sampled_token_ids is None:
+                assert self.async_copy_ready_event is not None
+                self.async_copy_ready_event.synchronize()
+                sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()
+            # Replace placeholder token id with actual sampled id.
+            req_output_token_ids[-1] = sampled_token_ids[prev_index]
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
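update_async_output_token_ids above back-fills trailing -1 placeholders with the ids sampled in the previous step, waiting on the copy-ready event at most once and only when a repair is actually needed. A simplified, CPU-only sketch of that idea, using plain lists instead of torch tensors and a threading.Event standing in for the torch.cuda.Event (the function name repair_output_token_ids is illustrative, not part of the patch):

import threading

PLACEHOLDER = -1  # stand-in for the placeholder id appended under async scheduling


def repair_output_token_ids(
    output_token_ids: list[list[int]],
    req_ids: list[str],
    prev_req_id_to_index: dict[str, int],
    prev_sampled: list[int],
    copy_ready: threading.Event,
) -> None:
    """Replace trailing placeholders with the previous step's sampled ids."""
    waited = False
    for index, req_id in enumerate(req_ids):
        prev_index = prev_req_id_to_index.get(req_id)
        if prev_index is None:
            # Request was not in the previous batch.
            continue
        token_ids = output_token_ids[index]
        if not token_ids or token_ids[-1] != PLACEHOLDER:
            # Nothing to repair for this request.
            continue
        if not waited:
            # Block only once, and only if some request needs the real ids.
            copy_ready.wait()
            waited = True
        token_ids[-1] = prev_sampled[prev_index]


if __name__ == "__main__":
    ready = threading.Event()
    ready.set()  # pretend the async device-to-host copy already finished
    outputs = [[5, 7, PLACEHOLDER], [3, PLACEHOLDER]]
    repair_output_token_ids(
        output_token_ids=outputs,
        req_ids=["a", "b"],
        prev_req_id_to_index={"a": 0, "b": 1},
        prev_sampled=[11, 22],
        copy_ready=ready,
    )
    print(outputs)  # [[5, 7, 11], [3, 22]]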
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ba27385589a7a..2dce58237c7b0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -178,7 +178,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         self._invalid_req_indices = invalid_req_indices
 
         # Event on the copy stream so we can synchronize the non-blocking copy.
-        self._async_copy_ready_event = torch.cuda.Event()
+        self.async_copy_ready_event = torch.cuda.Event()
 
         # Keep a reference to the device tensor to avoid it being
         # deallocated until we finish copying it to the host.
@@ -188,22 +188,22 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         default_stream = torch.cuda.current_stream()
         with torch.cuda.stream(async_output_copy_stream):
             async_output_copy_stream.wait_stream(default_stream)
-            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
+            self.sampled_token_ids_cpu = self._sampled_token_ids.to(
                 "cpu", non_blocking=True
             )
-            self._async_copy_ready_event.record()
+            self.async_copy_ready_event.record()
 
     def get_output(self) -> ModelRunnerOutput:
         """Copy the device tensors to the host and return a ModelRunnerOutput.
 
         This function blocks until the copy is finished.
         """
-        self._async_copy_ready_event.synchronize()
+        self.async_copy_ready_event.synchronize()
 
         # Release the device tensor once the copy has completed
         del self._sampled_token_ids
 
-        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
+        valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
         for i in self._invalid_req_indices:
             valid_sampled_token_ids[i].clear()
 
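The hunk above records a copy-ready event on a dedicated copy stream right after the non-blocking device-to-host copy of the sampled token ids; consumers synchronize on that event only when they actually need the values. A minimal standalone sketch of the same pattern, assuming PyTorch and a CUDA device (make_async_copy is an illustrative name, not a vLLM API):

import torch


def make_async_copy(sampled_token_ids: torch.Tensor, copy_stream: torch.cuda.Stream):
    """Start a non-blocking D2H copy on copy_stream and return (cpu_tensor, event)."""
    ready = torch.cuda.Event()
    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(copy_stream):
        # Ensure sampling work on the default stream finished before copying.
        copy_stream.wait_stream(default_stream)
        sampled_cpu = sampled_token_ids.to("cpu", non_blocking=True)
        ready.record()  # recorded on copy_stream inside this context
    return sampled_cpu, ready


if __name__ == "__main__" and torch.cuda.is_available():
    stream = torch.cuda.Stream()
    gpu_ids = torch.randint(0, 32000, (8, 1), device="cuda")
    cpu_ids, ready_event = make_async_copy(gpu_ids, stream)
    # Later, right before penalties or bad-words checks need the real ids:
    ready_event.synchronize()
    print(cpu_ids.squeeze(-1).tolist())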
@@ -349,6 +349,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # solution, we initialize the input batch here, and re-initialize it
         # in `initialize_kv_cache` if the block_sizes here is different from
         # the block_sizes in the kv cache config.
+        custom_logitsprocs = model_config.logits_processors
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
             # We need to use the encoder length for encoder-decoer
@@ -366,8 +367,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.device,
                 self.pin_memory,
                 self.is_pooling_model,
-                self.vllm_config.model_config.logits_processors,
+                custom_logitsprocs,
             ),
+            # We currently don't know whether a particular custom logits processor
+            # uses output token ids, so we set this conservatively.
+            logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
             is_pooling_model=self.is_pooling_model,
         )
 
@@ -2210,6 +2214,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
+            # Update output token ids with tokens sampled in the last step
+            # if async scheduling and required by current sampling params.
+            self.input_batch.update_async_output_token_ids()
             return self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
@@ -2666,13 +2673,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if not self.use_async_scheduling:
             return output
 
-        return AsyncGPUModelRunnerOutput(
+        async_output = AsyncGPUModelRunnerOutput(
             model_runner_output=output,
             sampled_token_ids=sampler_output.sampled_token_ids,
             invalid_req_indices=invalid_req_indices,
             async_output_copy_stream=self.async_output_copy_stream,
         )
 
+        # Save ref of sampled_token_ids CPU tensor if the batch contains
+        # any requests with sampling params that require output ids.
+        self.input_batch.set_async_sampled_token_ids(
+            async_output.sampled_token_ids_cpu,
+            async_output.async_copy_ready_event,
+        )
+
+        return async_output
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
             return None
@@ -4198,6 +4214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kernel_block_sizes=kernel_block_sizes,
             is_spec_decode=bool(self.vllm_config.speculative_config),
             logitsprocs=self.input_batch.logitsprocs,
+            logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
             is_pooling_model=self.is_pooling_model,
             num_speculative_tokens=(
                 self.vllm_config.speculative_config.num_speculative_tokens