diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 6447a33838d75..a558fbabb12e4 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -4,6 +4,7 @@ from itertools import repeat from typing import Any import pytest +import torch import torch._dynamo.config as dynamo_config from vllm import SamplingParams @@ -102,7 +103,10 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): test_sampling_params = [ dict(), + dict(presence_penalty=-1.0), + dict(bad_words=["the", " the"]), dict(logprobs=2), + dict(logprobs=2, presence_penalty=-1.0), ] # test_preemption, executor, async_scheduling, @@ -155,6 +159,7 @@ def run_tests( with monkeypatch.context() as m: # lock matmul precision to full FP32 (IEEE) m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee") + torch.backends.cuda.matmul.allow_tf32 = False # m.setenv("VLLM_BATCH_INVARIANT", "1") outputs: list[tuple[str, list, list]] = [] for n, ( diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index 1d43a8253843f..65f7248eafc94 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -147,22 +147,15 @@ class InputProcessor: raise ValueError( "vLLM V1 does not support per request user provided logits processors." ) - # Async scheduling + spec decode currently incompatible with some - # sampling parameters. + # Async scheduling + spec decode currently incompatible with structured outputs if ( self.vllm_config.speculative_config is not None and self.vllm_config.scheduler_config.async_scheduling - and ( - params.frequency_penalty != 0.0 - or params.presence_penalty != 0.0 - or params.repetition_penalty != 1.0 - or params.bad_words_token_ids - or params.structured_outputs - ) + and params.structured_outputs ): raise ValueError( "async scheduling with spec decoding doesn't yet support " - "penalties, bad words or structured outputs in sampling parameters." + "structured outputs in sampling parameters." ) def _validate_params( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 08b595845bb40..50a421c4a6fec 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -939,9 +939,50 @@ class InputBatch: if sampled_token_ids is None: assert self.async_copy_ready_event is not None self.async_copy_ready_event.synchronize() - sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist() - # Replace placeholder token id with actual sampled id. - req_output_token_ids[-1] = sampled_token_ids[prev_index] + sampled_token_ids = self.sampled_token_ids_cpu.tolist() + # Replace placeholder token id(s) with actual sampled id(s). + if sampled_ids := sampled_token_ids[prev_index]: + num_replace = 0 + for t in sampled_ids: + if t == -1: + break + num_replace += 1 + + if num_replace == 0: + continue + req_output_token_ids[-num_replace:] = sampled_ids[:num_replace] + + def update_async_spec_token_ids( + self, + draft_token_ids_cpu: list[list[int]] | None, + num_draft_tokens: list[int] | None = None, + ) -> None: + """ + In async scheduling case, update spec_token_ids in sampling metadata with + real draft token ids from prior step. This is called right before they are + needed by the rejection sampler for penalty/bad_words computation. + """ + if draft_token_ids_cpu is None or self.prev_req_id_to_index is None: + return + + spec_token_ids = self.sampling_metadata.spec_token_ids + if not spec_token_ids: + return + + for index, req_id in enumerate(self.req_ids): + prev_index = self.prev_req_id_to_index.get(req_id) + if prev_index is None: + continue + draft_ids = draft_token_ids_cpu[prev_index] + if not draft_ids: + continue + + if num_draft_tokens is not None: + scheduled_count = num_draft_tokens[index] + assert scheduled_count <= len(draft_ids) + draft_ids = draft_ids[:scheduled_count] + spec_token_ids[index].clear() + spec_token_ids[index].extend(draft_ids) @property def num_reqs(self) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 414ae33c6251f..97eed52805a48 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -591,9 +591,23 @@ class GPUModelRunner( # with dedicated stream for overlapping and event for coordination. self.valid_sampled_token_count_event: torch.Event | None = None self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None + # Pre-allocated tensor for copying draft token ids to CPU, + # with dedicated stream for overlapping and event for coordination. + self.draft_token_ids_copy_event: torch.Event | None = None + self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None + self.draft_token_ids_cpu: torch.Tensor | None = None if self.use_async_scheduling and self.num_spec_tokens: self.valid_sampled_token_count_event = torch.Event() self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() + self.draft_token_ids_copy_event = torch.Event() + self.draft_token_ids_copy_stream = torch.cuda.Stream() + self.draft_token_ids_cpu = torch.empty( + (self.max_num_reqs, self.num_spec_tokens), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + self._prev_copy_draft_num_reqs: int = 0 self.valid_sampled_token_count_cpu = torch.empty( self.max_num_reqs, dtype=torch.int64, @@ -2585,15 +2599,22 @@ class GPUModelRunner( ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata + # Update output token ids with tokens sampled in last step + # if async scheduling and required by current sampling params. + self.input_batch.update_async_output_token_ids() if spec_decode_metadata is None: - # Update output token ids with tokens sampled in last step - # if async scheduling and required by current sampling params. - self.input_batch.update_async_output_token_ids() return self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) + # Update spec_token_ids with real draft tokens from previous step + draft_token_ids_cpu = self._get_draft_token_ids_cpu() + self.input_batch.update_async_spec_token_ids( + draft_token_ids_cpu, + num_draft_tokens=spec_decode_metadata.num_draft_tokens, + ) + sampler_output = self.rejection_sampler( spec_decode_metadata, None, # draft_probs @@ -3446,6 +3467,43 @@ class GPUModelRunner( self.valid_sampled_token_count_event.synchronize() return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist() + def _copy_draft_token_ids( + self, draft_token_ids: torch.Tensor, num_reqs: int + ) -> None: + """Copy draft token ids to CPU asynchronously.""" + if self.draft_token_ids_copy_event is None or not isinstance( + draft_token_ids, torch.Tensor + ): + return + + self._prev_copy_draft_num_reqs = num_reqs + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.draft_token_ids_copy_stream): + self.draft_token_ids_copy_stream.wait_stream(default_stream) # type: ignore + # Copy draft_token_ids [num_reqs, num_spec_tokens] to pinned CPU + assert self.draft_token_ids_cpu is not None + self.draft_token_ids_cpu[:num_reqs].copy_( + draft_token_ids[:num_reqs], non_blocking=True + ) + self.draft_token_ids_copy_event.record() + + def _get_draft_token_ids_cpu(self) -> list[list[int]] | None: + """Get previously copied draft token ids from CPU.""" + if isinstance(self._draft_token_ids, list): + return self._draft_token_ids + + if ( + self.draft_token_ids_copy_event is None + or self.draft_token_ids_cpu is None + or not self._prev_copy_draft_num_reqs + ): + return None + + _prev_copy_draft_num_reqs = self._prev_copy_draft_num_reqs + self._prev_copy_draft_num_reqs = 0 + self.draft_token_ids_copy_event.synchronize() + return self.draft_token_ids_cpu[:_prev_copy_draft_num_reqs].tolist() + def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", @@ -3610,6 +3668,7 @@ class GPUModelRunner( num_rejected_tokens_gpu=num_rejected_tokens_gpu, ) + self._copy_draft_token_ids(draft_token_ids, self.input_batch.num_reqs) return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: