From 70189d8eb0863d499014616cf90cd3b7a9bbc9e5 Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Fri, 5 Dec 2025 15:51:43 +0800 Subject: [PATCH 01/14] fix: copy actual sampled_token_ids to req_output_token_ids Signed-off-by: zhuhaoran --- vllm/v1/worker/gpu_input_batch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ead7a3619dea5..0f523eacfa4bf 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -949,9 +949,11 @@ class InputBatch: if sampled_token_ids is None: assert self.async_copy_ready_event is not None self.async_copy_ready_event.synchronize() - sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist() + sampled_token_ids = self.sampled_token_ids_cpu.tolist() # Replace placeholder token id with actual sampled id. - req_output_token_ids[-1] = sampled_token_ids[prev_index] + req_output_token_ids[-len(sampled_token_ids[prev_index]) :] = ( + sampled_token_ids[prev_index] + ) @property def num_reqs(self) -> int: From 0999b7224a0f7334cdc9db2fb7be725380978647 Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Sat, 6 Dec 2025 13:28:08 +0800 Subject: [PATCH 02/14] fix: apply suggestions from @njhill Signed-off-by: zhuhaoran --- vllm/v1/worker/gpu_input_batch.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 0f523eacfa4bf..ff71c6d54179d 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -950,10 +950,9 @@ class InputBatch: assert self.async_copy_ready_event is not None self.async_copy_ready_event.synchronize() sampled_token_ids = self.sampled_token_ids_cpu.tolist() - # Replace placeholder token id with actual sampled id. - req_output_token_ids[-len(sampled_token_ids[prev_index]) :] = ( - sampled_token_ids[prev_index] - ) + # Replace placeholder token id(s) with actual sampled id(s). + if sampled_ids := sampled_token_ids[prev_index]: + req_output_token_ids[-len(sampled_ids) :] = sampled_ids @property def num_reqs(self) -> int: From 461e7b3e32cecaf798c9e19820980cda7474aebf Mon Sep 17 00:00:00 2001 From: izhuhaoran Date: Thu, 11 Dec 2025 19:32:42 +0800 Subject: [PATCH 03/14] feat: support for penalty and badwords for async + spec Signed-off-by: izhuhaoran --- vllm/v1/engine/input_processor.py | 13 +---- vllm/v1/worker/gpu_input_batch.py | 25 ++++++++ vllm/v1/worker/gpu_model_runner.py | 91 +++++++++++++++++++++++++++--- 3 files changed, 110 insertions(+), 19 deletions(-) diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index e6a94f4e3de5d..d7454592ab1f7 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -147,22 +147,15 @@ class InputProcessor: raise ValueError( "vLLM V1 does not support per request user provided logits processors." ) - # Async scheduling + spec decode currently incompatible with some - # sampling parameters. + # Async scheduling + spec decode currently incompatible with structured outputs if ( self.vllm_config.speculative_config is not None and self.vllm_config.scheduler_config.async_scheduling - and ( - params.frequency_penalty != 0.0 - or params.presence_penalty != 0.0 - or params.repetition_penalty != 1.0 - or params.bad_words_token_ids - or params.structured_outputs - ) + and params.structured_outputs ): raise ValueError( "async scheduling with spec decoding doesn't yet support " - "penalties, bad words or structured outputs in sampling parameters." + "structured outputs in sampling parameters." ) def _validate_params( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ff71c6d54179d..10f5df96f9535 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -954,6 +954,31 @@ class InputBatch: if sampled_ids := sampled_token_ids[prev_index]: req_output_token_ids[-len(sampled_ids) :] = sampled_ids + def update_async_spec_token_ids( + self, draft_token_ids_cpu: list[list[int]] | None + ) -> None: + """ + In async scheduling case, update spec_token_ids in sampling metadata + with real draft token ids from prior step. + This is called right before they are needed by the rejection sampler + for penalty/bad_words computation. + """ + if draft_token_ids_cpu is None or self.prev_req_id_to_index is None: + return + + spec_token_ids = self.sampling_metadata.spec_token_ids + if not spec_token_ids: + return + + for index, req_id in enumerate(self.req_ids): + prev_index = self.prev_req_id_to_index.get(req_id, default=None) + if prev_index is None: + continue + assert prev_index < len(draft_token_ids_cpu) + draft_ids = draft_token_ids_cpu[prev_index] + assert index < len(spec_token_ids) + spec_token_ids[index] = draft_ids + @property def num_reqs(self) -> int: return len(self.req_id_to_index) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 39456d2e80ed0..47f9ac90b7183 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -591,15 +591,32 @@ class GPUModelRunner( # with dedicated stream for overlapping and event for coordination. self.valid_sampled_token_count_event: torch.Event | None = None self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None + self.valid_sampled_token_count_cpu: torch.Tensor | None = None + # Pre-allocated tensor for copying draft token ids to CPU, + # with dedicated stream for overlapping and event for coordination. + self.draft_token_ids_copy_event: torch.Event | None = None + self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None + self.draft_token_ids_cpu: torch.Tensor | None = None if self.use_async_scheduling and self.num_spec_tokens: self.valid_sampled_token_count_event = torch.Event() self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() - self.valid_sampled_token_count_cpu = torch.empty( - self.max_num_reqs, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory, - ) + self.valid_sampled_token_count_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + self.draft_token_ids_copy_event = torch.Event() + self.draft_token_ids_copy_stream = torch.cuda.Stream() + self.draft_token_ids_cpu = torch.empty( + (self.max_num_reqs, self.num_spec_tokens), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + # Flag to track if valid draft tokens were copied this step. + # Reset to False at step start, set True in _copy_draft_token_ids. + self._has_draft_tokens: bool = False # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None @@ -2555,15 +2572,19 @@ class GPUModelRunner( ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata + # Update output token ids with tokens sampled in last step + # if async scheduling and required by current sampling params. + self.input_batch.update_async_output_token_ids() if spec_decode_metadata is None: - # Update output token ids with tokens sampled in last step - # if async scheduling and required by current sampling params. - self.input_batch.update_async_output_token_ids() return self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) + # Update spec_token_ids with real draft tokens from previous step + draft_token_ids_cpu = self._get_draft_token_ids_cpu() + self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu) + sampler_output = self.rejection_sampler( spec_decode_metadata, None, # draft_probs @@ -3370,6 +3391,55 @@ class GPUModelRunner( self.valid_sampled_token_count_event.synchronize() return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist() + def _copy_draft_token_ids( + self, draft_token_ids: torch.Tensor, num_reqs: int + ) -> None: + """Copy draft token ids to CPU asynchronously. + + This is used for async scheduling with spec decode + penalty/bad_words. + The draft_token_ids will be used in the next step to update + input_batch.spec_token_ids for correct penalty/bad_words computation. + """ + if self.draft_token_ids_copy_event is None or not isinstance( + draft_token_ids, torch.Tensor + ): + return + + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.draft_token_ids_copy_stream): + self.draft_token_ids_copy_stream.wait_stream(default_stream) # type: ignore + # Copy draft_token_ids [num_reqs, num_spec_tokens] to pinned CPU + assert self.draft_token_ids_cpu is not None + self.draft_token_ids_cpu[:num_reqs].copy_( + draft_token_ids[:num_reqs], non_blocking=True + ) + self.draft_token_ids_copy_event.record() + self._has_draft_tokens = True + + def _get_draft_token_ids_cpu(self) -> list[list[int]] | None: + """Get previously copied draft token ids from CPU. + + Called at the beginning of the next step to update spec_token_ids + for async scheduling with spec decode + penalty/bad_words. + Returns None if no draft tokens were copied in previous step. + """ + if isinstance(self._draft_token_ids, list): + return self._draft_token_ids + + if ( + self.draft_token_ids_copy_event is None + or self.draft_token_ids_cpu is None + or not self._has_draft_tokens + ): + return None + + self._has_draft_tokens = False + self.draft_token_ids_copy_event.synchronize() + num_reqs = self.input_batch.num_reqs + if num_reqs == 0: + return None + return self.draft_token_ids_cpu[:num_reqs].tolist() + def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", @@ -3530,6 +3600,9 @@ class GPUModelRunner( mm_embed_inputs=mm_embed_inputs, ) + self._copy_draft_token_ids( + self._draft_token_ids, self.input_batch.num_reqs + ) return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: From 78555c06a20fff084b9cc4f80a444928b36ed18f Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Fri, 12 Dec 2025 00:02:47 +0800 Subject: [PATCH 04/14] fix bug Signed-off-by: zhuhaoran --- vllm/v1/worker/gpu_input_batch.py | 20 +++++++++++++++++--- vllm/v1/worker/gpu_model_runner.py | 13 +++++-------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 10f5df96f9535..786dfe95106e6 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -952,7 +952,18 @@ class InputBatch: sampled_token_ids = self.sampled_token_ids_cpu.tolist() # Replace placeholder token id(s) with actual sampled id(s). if sampled_ids := sampled_token_ids[prev_index]: - req_output_token_ids[-len(sampled_ids) :] = sampled_ids + num_placeholders = 0 + for t in reversed(req_output_token_ids): + if t == -1: + num_placeholders += 1 + else: + break + if num_placeholders == 0: + continue + assert num_placeholders <= len(sampled_ids) + req_output_token_ids[-num_placeholders:] = sampled_ids[ + :num_placeholders + ] def update_async_spec_token_ids( self, draft_token_ids_cpu: list[list[int]] | None @@ -971,13 +982,16 @@ class InputBatch: return for index, req_id in enumerate(self.req_ids): - prev_index = self.prev_req_id_to_index.get(req_id, default=None) + prev_index = self.prev_req_id_to_index.get(req_id) if prev_index is None: continue assert prev_index < len(draft_token_ids_cpu) draft_ids = draft_token_ids_cpu[prev_index] + if not draft_ids: + continue assert index < len(spec_token_ids) - spec_token_ids[index] = draft_ids + spec_token_ids[index].clear() + spec_token_ids[index].extend(draft_ids) @property def num_reqs(self) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 47f9ac90b7183..5c85a9bf399d9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3423,14 +3423,13 @@ class GPUModelRunner( for async scheduling with spec decode + penalty/bad_words. Returns None if no draft tokens were copied in previous step. """ + if not self._has_draft_tokens: + return None + if isinstance(self._draft_token_ids, list): return self._draft_token_ids - if ( - self.draft_token_ids_copy_event is None - or self.draft_token_ids_cpu is None - or not self._has_draft_tokens - ): + if self.draft_token_ids_copy_event is None or self.draft_token_ids_cpu is None: return None self._has_draft_tokens = False @@ -3600,9 +3599,7 @@ class GPUModelRunner( mm_embed_inputs=mm_embed_inputs, ) - self._copy_draft_token_ids( - self._draft_token_ids, self.input_batch.num_reqs - ) + self._copy_draft_token_ids(draft_token_ids, self.input_batch.num_reqs) return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: From 049041874227b619fe3c78ce90bec0468bd5aef8 Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Fri, 12 Dec 2025 01:08:23 +0800 Subject: [PATCH 05/14] fix bug about pre num_reqs maybe != current num_reqs Signed-off-by: zhuhaoran --- vllm/v1/worker/gpu_model_runner.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5c85a9bf399d9..0c7d8d2b2c435 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -614,9 +614,7 @@ class GPUModelRunner( device="cpu", pin_memory=self.pin_memory, ) - # Flag to track if valid draft tokens were copied this step. - # Reset to False at step start, set True in _copy_draft_token_ids. - self._has_draft_tokens: bool = False + self._prev_copy_draft_num_reqs: int = 0 # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None @@ -3405,6 +3403,7 @@ class GPUModelRunner( ): return + self._prev_copy_draft_num_reqs = num_reqs default_stream = torch.cuda.current_stream() with torch.cuda.stream(self.draft_token_ids_copy_stream): self.draft_token_ids_copy_stream.wait_stream(default_stream) # type: ignore @@ -3414,7 +3413,6 @@ class GPUModelRunner( draft_token_ids[:num_reqs], non_blocking=True ) self.draft_token_ids_copy_event.record() - self._has_draft_tokens = True def _get_draft_token_ids_cpu(self) -> list[list[int]] | None: """Get previously copied draft token ids from CPU. @@ -3423,21 +3421,20 @@ class GPUModelRunner( for async scheduling with spec decode + penalty/bad_words. Returns None if no draft tokens were copied in previous step. """ - if not self._has_draft_tokens: - return None - if isinstance(self._draft_token_ids, list): return self._draft_token_ids - if self.draft_token_ids_copy_event is None or self.draft_token_ids_cpu is None: + if ( + self.draft_token_ids_copy_event is None + or self.draft_token_ids_cpu is None + or not self._prev_copy_draft_num_reqs + ): return None - self._has_draft_tokens = False + _prev_copy_draft_num_reqs = self._prev_copy_draft_num_reqs + self._prev_copy_draft_num_reqs = 0 self.draft_token_ids_copy_event.synchronize() - num_reqs = self.input_batch.num_reqs - if num_reqs == 0: - return None - return self.draft_token_ids_cpu[:num_reqs].tolist() + return self.draft_token_ids_cpu[:_prev_copy_draft_num_reqs].tolist() def propose_draft_token_ids( self, From aec572f39d8b703e4772965b99fc8dd261f71081 Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Fri, 12 Dec 2025 10:56:06 +0800 Subject: [PATCH 06/14] lint: fix mypy error Signed-off-by: zhuhaoran --- vllm/v1/worker/gpu_model_runner.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0c7d8d2b2c435..51812c87227de 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -591,7 +591,6 @@ class GPUModelRunner( # with dedicated stream for overlapping and event for coordination. self.valid_sampled_token_count_event: torch.Event | None = None self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None - self.valid_sampled_token_count_cpu: torch.Tensor | None = None # Pre-allocated tensor for copying draft token ids to CPU, # with dedicated stream for overlapping and event for coordination. self.draft_token_ids_copy_event: torch.Event | None = None @@ -600,12 +599,6 @@ class GPUModelRunner( if self.use_async_scheduling and self.num_spec_tokens: self.valid_sampled_token_count_event = torch.Event() self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() - self.valid_sampled_token_count_cpu = torch.empty( - self.max_num_reqs, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory, - ) self.draft_token_ids_copy_event = torch.Event() self.draft_token_ids_copy_stream = torch.cuda.Stream() self.draft_token_ids_cpu = torch.empty( @@ -615,6 +608,12 @@ class GPUModelRunner( pin_memory=self.pin_memory, ) self._prev_copy_draft_num_reqs: int = 0 + self.valid_sampled_token_count_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None From 33c63f263d588f6db142dc4d44cdfc16671b2f39 Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Fri, 12 Dec 2025 22:59:57 +0800 Subject: [PATCH 07/14] fix: use num_draft_tokens to trim draft_token_ids_cpu Signed-off-by: zhuhaoran --- vllm/v1/worker/gpu_input_batch.py | 11 ++++++++--- vllm/v1/worker/gpu_model_runner.py | 5 ++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 786dfe95106e6..ee2f6b9964e55 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -966,7 +966,9 @@ class InputBatch: ] def update_async_spec_token_ids( - self, draft_token_ids_cpu: list[list[int]] | None + self, + draft_token_ids_cpu: list[list[int]] | None, + num_draft_tokens: list[int] | None = None, ) -> None: """ In async scheduling case, update spec_token_ids in sampling metadata @@ -985,11 +987,14 @@ class InputBatch: prev_index = self.prev_req_id_to_index.get(req_id) if prev_index is None: continue - assert prev_index < len(draft_token_ids_cpu) draft_ids = draft_token_ids_cpu[prev_index] if not draft_ids: continue - assert index < len(spec_token_ids) + + if num_draft_tokens is not None: + scheduled_count = num_draft_tokens[index] + assert scheduled_count <= len(draft_ids) + draft_ids = draft_ids[:scheduled_count] spec_token_ids[index].clear() spec_token_ids[index].extend(draft_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f6aa773def973..0da2f082052d6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2582,7 +2582,10 @@ class GPUModelRunner( # Update spec_token_ids with real draft tokens from previous step draft_token_ids_cpu = self._get_draft_token_ids_cpu() - self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu) + self.input_batch.update_async_spec_token_ids( + draft_token_ids_cpu, + num_draft_tokens=spec_decode_metadata.num_draft_tokens, + ) sampler_output = self.rejection_sampler( spec_decode_metadata, From 4672c4d035af58d304ab24dcb74a4a054570cce4 Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Sat, 13 Dec 2025 00:01:56 +0800 Subject: [PATCH 08/14] clean useless func comment Signed-off-by: zhuhaoran --- vllm/v1/worker/gpu_input_batch.py | 7 +++---- vllm/v1/worker/gpu_model_runner.py | 14 ++------------ 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ee2f6b9964e55..eff7e99933305 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -971,10 +971,9 @@ class InputBatch: num_draft_tokens: list[int] | None = None, ) -> None: """ - In async scheduling case, update spec_token_ids in sampling metadata - with real draft token ids from prior step. - This is called right before they are needed by the rejection sampler - for penalty/bad_words computation. + In async scheduling case, update spec_token_ids in sampling metadata with + real draft token ids from prior step. This is called right before they are + needed by the rejection sampler for penalty/bad_words computation. """ if draft_token_ids_cpu is None or self.prev_req_id_to_index is None: return diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0da2f082052d6..92b6ad31f4443 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3403,12 +3403,7 @@ class GPUModelRunner( def _copy_draft_token_ids( self, draft_token_ids: torch.Tensor, num_reqs: int ) -> None: - """Copy draft token ids to CPU asynchronously. - - This is used for async scheduling with spec decode + penalty/bad_words. - The draft_token_ids will be used in the next step to update - input_batch.spec_token_ids for correct penalty/bad_words computation. - """ + """Copy draft token ids to CPU asynchronously.""" if self.draft_token_ids_copy_event is None or not isinstance( draft_token_ids, torch.Tensor ): @@ -3426,12 +3421,7 @@ class GPUModelRunner( self.draft_token_ids_copy_event.record() def _get_draft_token_ids_cpu(self) -> list[list[int]] | None: - """Get previously copied draft token ids from CPU. - - Called at the beginning of the next step to update spec_token_ids - for async scheduling with spec decode + penalty/bad_words. - Returns None if no draft tokens were copied in previous step. - """ + """Get previously copied draft token ids from CPU.""" if isinstance(self._draft_token_ids, list): return self._draft_token_ids From 3de2e4bff0d86abb5909f7e934dc1a964a5ef2a1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 12 Dec 2025 10:17:01 -0800 Subject: [PATCH 09/14] update e2e test Signed-off-by: Nick Hill --- tests/v1/e2e/test_async_scheduling.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 5cef9b33c9984..c320b7c91e76d 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -101,8 +101,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): spec_config_short = spec_config | {"max_model_len": 50} test_sampling_params = [ - dict(), + dict(presence_penalty=-1.0), + dict(bad_words=["the", " the"]), dict(logprobs=2), + # TODO there is a logprobs diff for this combo, independent of async scheduling + # dict(logprobs=2, presence_penalty=-1.0), ] # test_preemption, executor, async_scheduling, From 7feb2f2a6d5b37d78d871739b3db54d7a28de9f7 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 12 Dec 2025 10:32:37 -0800 Subject: [PATCH 10/14] update e2e test Signed-off-by: Nick Hill --- tests/v1/e2e/test_async_scheduling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index c320b7c91e76d..137a5e6edef2e 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -101,6 +101,8 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): spec_config_short = spec_config | {"max_model_len": 50} test_sampling_params = [ + dict(), + dict(logprobs=2), dict(presence_penalty=-1.0), dict(bad_words=["the", " the"]), dict(logprobs=2), From c5df2565ab73841bef982ed87799b943c4f117be Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 12 Dec 2025 10:34:23 -0800 Subject: [PATCH 11/14] update e2e test - oops Signed-off-by: Nick Hill --- tests/v1/e2e/test_async_scheduling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 137a5e6edef2e..307b6e66682f6 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -102,7 +102,6 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): test_sampling_params = [ dict(), - dict(logprobs=2), dict(presence_penalty=-1.0), dict(bad_words=["the", " the"]), dict(logprobs=2), From b4d755ac050094cb903521ee035b034790136cb4 Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Sat, 13 Dec 2025 18:09:36 +0800 Subject: [PATCH 12/14] fix inductor tf32 setting error Signed-off-by: zhuhaoran --- tests/v1/e2e/test_async_scheduling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 307b6e66682f6..5d64a41fbc9a5 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -4,6 +4,7 @@ from itertools import repeat from typing import Any import pytest +import torch import torch._dynamo.config as dynamo_config from vllm import SamplingParams @@ -158,6 +159,7 @@ def run_tests( m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") # lock matmul precision to full FP32 (IEEE) m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee") + torch.backends.cuda.matmul.allow_tf32 = False # m.setenv("VLLM_BATCH_INVARIANT", "1") outputs: list[tuple[str, list, list]] = [] for n, ( From 8d339e86e5ed4328e8be9497a933ff85a2706d20 Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Tue, 23 Dec 2025 17:52:38 +0800 Subject: [PATCH 13/14] fix corner case for update_async_output_token_ids Signed-off-by: zhuhaoran --- vllm/v1/worker/gpu_input_batch.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index b09e982e70137..50a421c4a6fec 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -942,18 +942,15 @@ class InputBatch: sampled_token_ids = self.sampled_token_ids_cpu.tolist() # Replace placeholder token id(s) with actual sampled id(s). if sampled_ids := sampled_token_ids[prev_index]: - num_placeholders = 0 - for t in reversed(req_output_token_ids): + num_replace = 0 + for t in sampled_ids: if t == -1: - num_placeholders += 1 - else: break - if num_placeholders == 0: + num_replace += 1 + + if num_replace == 0: continue - assert num_placeholders <= len(sampled_ids) - req_output_token_ids[-num_placeholders:] = sampled_ids[ - :num_placeholders - ] + req_output_token_ids[-num_replace:] = sampled_ids[:num_replace] def update_async_spec_token_ids( self, From 9b1a8cc76a25648a199467aeb95e23b6377a0b4e Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Tue, 23 Dec 2025 22:06:51 +0800 Subject: [PATCH 14/14] remove todo for logprobs and penalty combo Signed-off-by: zhuhaoran --- tests/v1/e2e/test_async_scheduling.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 28d75ee156a68..a558fbabb12e4 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -106,8 +106,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): dict(presence_penalty=-1.0), dict(bad_words=["the", " the"]), dict(logprobs=2), - # TODO there is a logprobs diff for this combo, independent of async scheduling - # dict(logprobs=2, presence_penalty=-1.0), + dict(logprobs=2, presence_penalty=-1.0), ] # test_preemption, executor, async_scheduling,