From d8874c61a55e40db4ada047f1736c38c86439fff Mon Sep 17 00:00:00 2001 From: Ronald Date: Tue, 18 Nov 2025 04:16:20 +0800 Subject: [PATCH] [Core] Async Scheduling X Spec Decoding Compatibility (#24799) Signed-off-by: Ronald1995 Signed-off-by: Nick Hill Signed-off-by: Benjamin Chislett Co-authored-by: Nick Hill Co-authored-by: Benjamin Chislett --- tests/v1/e2e/test_async_scheduling.py | 38 +-- vllm/config/speculative.py | 38 ++- vllm/config/vllm.py | 21 +- vllm/v1/core/sched/async_scheduler.py | 15 +- vllm/v1/core/sched/scheduler.py | 12 +- vllm/v1/engine/core.py | 6 +- vllm/v1/engine/processor.py | 17 ++ vllm/v1/sample/logits_processor/__init__.py | 2 +- vllm/v1/spec_decode/eagle.py | 7 +- vllm/v1/worker/gpu_input_batch.py | 3 + vllm/v1/worker/gpu_model_runner.py | 253 +++++++++++++++++--- 11 files changed, 314 insertions(+), 98 deletions(-) diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index c4aca82416cd..f732b05f09f9 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -15,7 +15,7 @@ from ...conftest import VllmRunner from ...models.utils import check_outputs_equal MODEL = "Qwen/Qwen3-0.6B" -MTP_MODEL = "XiaomiMiMo/MiMo-7B-Base" +MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct" first_prompt = ( @@ -29,7 +29,8 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [ default_params = dict( temperature=0.0, # greedy - max_tokens=20, + max_tokens=23, + min_tokens=18, ) @@ -69,15 +70,9 @@ def test_without_spec_decoding( (True, "uni", True, None, True), ] - run_tests( - monkeypatch, - MODEL, - test_configs, - test_sampling_params, - ) + run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) -@pytest.mark.skip("MTP model too big to run in fp32 in CI") def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): """Test consistency and acceptance rates with some different combos of preemption, executor, async scheduling, prefill chunking, @@ -85,8 +80,9 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): """ spec_config = { - "method": "mtp", + "method": "eagle3", "num_speculative_tokens": 2, + "model": "nm-testing/Llama3_2_1B_speculator.eagle3", } spec_config_short = spec_config | {"max_model_len": 50} @@ -106,12 +102,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): (True, "uni", True, spec_config_short, True), ] - run_tests( - monkeypatch, - MTP_MODEL, - test_configs, - [{}], - ) + run_tests(monkeypatch, MTP_MODEL, test_configs, [{}]) @dynamo_config.patch(cache_size_limit=16) @@ -182,15 +173,13 @@ def run_tests( and test_acceptance_rate is not None ): if "spec_mml=None" in test_config: - # because the acceptance rate can vary, we use a looser - # tolerance here. assert ( pytest.approx(test_acceptance_rate, rel=5e-2) == base_acceptance_rate ) else: # Currently the reported acceptance rate is expected to be - # lower when we skip drafting altogether. + # lower when we sometimes skip drafting altogether. 
assert test_acceptance_rate > 0.05 print( f"PASSED: config=[{test_config}], params={params}" @@ -220,6 +209,7 @@ def run_test( ): spec_decoding = spec_config is not None cache_arg: dict[str, Any] = ( + # Force preemptions dict(num_gpu_blocks_override=32) if test_preemption else dict(gpu_memory_utilization=0.9) @@ -238,6 +228,7 @@ def run_test( model, max_model_len=512, enable_chunked_prefill=test_prefill_chunking, + # Force prefill chunking max_num_batched_tokens=48 if test_prefill_chunking else None, # enforce_eager=True, async_scheduling=async_scheduling, @@ -255,10 +246,7 @@ def run_test( results.append( vllm_model.generate( example_prompts, - sampling_params=SamplingParams( - **default_params, - **override_params, - ), + sampling_params=SamplingParams(**default_params, **override_params), return_logprobs=True, ) ) @@ -270,9 +258,7 @@ def run_test( if test_preemption: preemptions = _get_count( - metrics_before, - metrics_after, - "vllm:num_preemptions", + metrics_before, metrics_after, "vllm:num_preemptions" ) assert preemptions > 0, "preemption test had no preemptions" diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 31cdeabe501d..13a8632413d9 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -3,7 +3,7 @@ import ast import hashlib -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, get_args from pydantic import Field, SkipValidation, model_validator from pydantic.dataclasses import dataclass @@ -29,31 +29,25 @@ else: logger = init_logger(__name__) -SpeculativeMethod = Literal[ - "ngram", - "eagle", - "eagle3", - "medusa", - "mlp_speculator", - "draft_model", - "deepseek_mtp", - "ernie_mtp", - "qwen3_next_mtp", - "mimo_mtp", - "longcat_flash_mtp", - "pangu_ultra_moe_mtp", - "mtp", - "suffix", -] -MTP_MODEL_TYPES = ( +MTPModelTypes = Literal[ "deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp", "qwen3_next_mtp", "longcat_flash_mtp", + "mtp", "pangu_ultra_moe_mtp", -) +] +EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] +SpeculativeMethod = Literal[ + "ngram", + "medusa", + "mlp_speculator", + "draft_model", + "suffix", + EagleModelTypes, +] @config @@ -244,7 +238,7 @@ class SpeculativeConfig: # can not be detected, it will be considered as the "draft_model" by # default. 
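Aside: a minimal, self-contained sketch (aliases abbreviated, not part of the patch) of how the nested Literal aliases above behave with typing.get_args, which is what the new membership checks in SpeculativeConfig and VllmConfig rely on. Nested Literal aliases are flattened, so get_args() sees every allowed string:

from typing import Literal, get_args

# Abbreviated stand-ins for the aliases defined in vllm/config/speculative.py.
MTPModelTypes = Literal["deepseek_mtp", "mimo_mtp", "mtp"]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]

assert "mimo_mtp" in get_args(MTPModelTypes)
assert "eagle3" in get_args(EagleModelTypes)
# MTP methods count as EAGLE-style for the async-scheduling check.
assert "mimo_mtp" in get_args(EagleModelTypes)
assert "ngram" not in get_args(EagleModelTypes)
print(get_args(EagleModelTypes))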
-        if self.method in MTP_MODEL_TYPES:
+        if self.method in get_args(MTPModelTypes) and self.method != "mtp":
             logger.warning(
                 "method `%s` is deprecated and replaced with mtp.", self.method
             )
@@ -361,7 +355,9 @@ class SpeculativeConfig:
                     self.method = "medusa"
                 elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                     self.method = "mlp_speculator"
-                elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
+                elif self.draft_model_config.hf_config.model_type in get_args(
+                    MTPModelTypes
+                ):
                     self.method = "mtp"
                     if self.num_speculative_tokens > 1:
                         logger.warning(
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index bd98be48588f..672b004c4aa5 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -14,13 +14,14 @@ from dataclasses import replace
 from datetime import datetime
 from functools import lru_cache
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar, get_args
 
 import torch
 from pydantic import ConfigDict, Field, model_validator
 from pydantic.dataclasses import dataclass
 
 import vllm.envs as envs
+from vllm.config.speculative import EagleModelTypes
 from vllm.logger import enable_trace_function_call, init_logger
 from vllm.transformers_utils.runai_utils import is_runai_obj_uri
 from vllm.utils import random_uuid
@@ -374,10 +375,22 @@ class VllmConfig:
                     "Async scheduling is not yet compatible with "
                     "pipeline_parallel_size > 1."
                 )
+            # Currently, async scheduling only supports EAGLE/MTP-style
+            # speculative decoding.
             if self.speculative_config is not None:
-                raise ValueError(
-                    "Async scheduling is not yet compatible with speculative decoding."
-                )
+                if self.speculative_config.method not in get_args(EagleModelTypes):
+                    raise ValueError(
+                        "Currently, async scheduling is only supported "
+                        "with EAGLE/MTP-style speculative decoding."
+                    )
+                if self.speculative_config.disable_padded_drafter_batch:
+                    raise ValueError(
+                        "Async scheduling is enabled together with EAGLE/MTP "
+                        "speculative decoding, but "
+                        "disable_padded_drafter_batch=True is not supported in "
+                        "this configuration. Please set "
+                        "disable_padded_drafter_batch=False."
+                    )
             if not executor_supports_async_sched:
                 raise ValueError(
                     "Currently, async scheduling only supports `mp`, `uni`, or "
diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py
index 0ad994c360b0..3214f65a0972 100644
--- a/vllm/v1/core/sched/async_scheduler.py
+++ b/vllm/v1/core/sched/async_scheduler.py
@@ -16,18 +16,25 @@ class AsyncScheduler(Scheduler):
     ) -> None:
         super()._update_after_schedule(scheduler_output)
         pending_structured_output_tokens = False
+        spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
         for req_id in scheduler_output.num_scheduled_tokens:
            request = self.requests[req_id]
            pending_structured_output_tokens |= (
                request.use_structured_output and request.num_output_placeholders > 0
            )
+            cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
            if (
                request.num_computed_tokens
-                == request.num_tokens + request.num_output_placeholders
+                == request.num_tokens
+                + request.num_output_placeholders
+                + cur_num_spec_tokens
            ):
-                # The request will generate a new token in this scheduling step.
-                # TODO(woosuk): Support speculative decoding.
-                request.num_output_placeholders += 1
+                # The request will generate a new token plus num_spec_tokens
+                # in this scheduling step.
+                request.num_output_placeholders += 1 + cur_num_spec_tokens
+                # Add placeholders for the new tokens in spec_token_ids.
+                # We will update the actual spec token ids in the worker process.
+                request.spec_token_ids = [-1] * self.num_spec_tokens
 
        scheduler_output.pending_structured_output_tokens = (
            pending_structured_output_tokens
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 8e62542337a7..61640e856ac1 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -348,7 +348,10 @@ class Scheduler(SchedulerInterface):
             # Speculative decode related.
             if request.spec_token_ids:
                 num_scheduled_spec_tokens = (
-                    num_new_tokens + request.num_computed_tokens - request.num_tokens
+                    num_new_tokens
+                    + request.num_computed_tokens
+                    - request.num_tokens
+                    - request.num_output_placeholders
                 )
                 if num_scheduled_spec_tokens > 0:
                     # Trim spec_token_ids list to num_scheduled_spec_tokens.
@@ -1024,7 +1027,12 @@ class Scheduler(SchedulerInterface):
                 # tokens and rejections. If some tokens are rejected,
                 # num_computed_tokens is decreased by the number of rejected
                 # tokens.
-                request.num_computed_tokens -= num_rejected
+                if request.num_computed_tokens > 0:
+                    request.num_computed_tokens -= num_rejected
+                # If async scheduling, num_output_placeholders also includes
+                # the scheduled spec tokens count and so is similarly adjusted.
+                if request.num_output_placeholders > 0:
+                    request.num_output_placeholders -= num_rejected
                 spec_decoding_stats = self.make_spec_decoding_stats(
                     spec_decoding_stats,
                     num_draft_tokens=num_draft_tokens,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index a6965182fc2c..508669cf527d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -198,6 +198,7 @@ class EngineCore:
         self.step_fn = (
             self.step if self.batch_queue is None else self.step_with_batch_queue
         )
+        self.async_scheduling = vllm_config.scheduler_config.async_scheduling
 
         # Mark the startup heap as static so that it's ignored by GC.
         # Reduces pause times of oldest generation collections.
@@ -341,7 +342,10 @@ class EngineCore:
         return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
 
     def post_step(self, model_executed: bool) -> None:
-        if self.use_spec_decode and model_executed:
+        # When using async scheduling we can't get the draft token ids in
+        # advance; they are updated in the worker process instead, so there
+        # is no need to take them here.
+        if not self.async_scheduling and self.use_spec_decode and model_executed:
             # Take the draft token ids.
             draft_token_ids = self.model_executor.take_draft_token_ids()
             if draft_token_ids is not None:
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index fffd075a5165..4cb911d8e22b 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -150,6 +150,23 @@ class Processor:
             raise ValueError(
                 "vLLM V1 does not support per request user provided logits processors."
             )
+        # Async scheduling + spec decode is currently incompatible with some
+        # sampling parameters.
+        if (
+            self.vllm_config.speculative_config is not None
+            and self.vllm_config.scheduler_config.async_scheduling
+            and (
+                params.frequency_penalty != 0.0
+                or params.presence_penalty != 0.0
+                or params.repetition_penalty != 1.0
+                or params.bad_words_token_ids
+                or params.structured_outputs
+            )
+        ):
+            raise ValueError(
+                "Async scheduling with spec decoding doesn't yet support "
+                "penalties, bad words, or structured outputs in sampling parameters."
+ ) def _validate_params( self, diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 5992c4066c9c..8b174af4c779 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -41,7 +41,7 @@ STR_POOLING_REJECTS_LOGITSPROCS = ( # Error message when the user tries to initialize vLLM with a speculative # decoding enabled and custom logitsproces STR_SPEC_DEC_REJECTS_LOGITSPROCS = ( - "Custom logits processors are not supportedwhen speculative decoding is enabled." + "Custom logits processors are not supported when speculative decoding is enabled." ) LOGITSPROCS_GROUP = "vllm.logits_processors" diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index ed602f39d0f9..5bf2503c3027 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -397,10 +397,13 @@ class EagleProposer: positions += 1 exceeds_max_model_len = positions >= self.max_model_len clamped_positions = torch.where(exceeds_max_model_len, 0, positions) - + # For data integrity when async scheduling, we shouldn't use in place + # operations in case they are modified in next step's `prepare_input` + # of main model. # Increment the sequence lengths. common_attn_metadata.seq_lens += 1 - common_attn_metadata.seq_lens_cpu += 1 + # This is an out-of-place operation to avoid modifying the original tensor. + common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 393181f543d2..7cf6afa3fc37 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -46,6 +46,9 @@ class CachedRequestState: lora_request: LoRARequest | None = None prompt_embeds: torch.Tensor | None = None + # Used when both async_scheduling and spec_decode are enabled. + prev_num_draft_len: int = 0 + def __post_init__(self): self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.prompt_token_ids, self.prompt_embeds diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4fe1b6487d58..758e3e1b3a82 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -179,6 +179,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): logprobs_tensors: torch.Tensor | None, invalid_req_indices: list[int], async_output_copy_stream: torch.cuda.Stream, + vocab_size: int, ): self._model_runner_output = model_runner_output self._invalid_req_indices = invalid_req_indices @@ -189,6 +190,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): # Keep a reference to the device tensor to avoid it being # deallocated until we finish copying it to the host. self._sampled_token_ids = sampled_token_ids + self.vocab_size = vocab_size self._logprobs_tensors = logprobs_tensors # Initiate the copy on a separate stream, but do not synchronize it. @@ -215,10 +217,16 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): # Release the device tensors once the copy has completed. 
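Both the async output copy above and _copy_valid_sampled_token_count further down follow the same side-stream copy pattern. A minimal standalone sketch of that pattern (assuming a CUDA device is available; the helper name is illustrative, not part of the patch):

import torch

def async_d2h_copy(src_gpu: torch.Tensor) -> tuple[torch.Tensor, torch.cuda.Event]:
    # Kick off the device-to-host copy on a dedicated stream into pinned
    # memory and record an event; the caller synchronizes on the event only
    # when the CPU values are actually needed, so other work can overlap.
    copy_stream = torch.cuda.Stream()
    copy_event = torch.cuda.Event()
    dst_cpu = torch.empty(
        src_gpu.shape, dtype=src_gpu.dtype, device="cpu", pin_memory=True
    )
    with torch.cuda.stream(copy_stream):
        # Wait for the producer kernels on the current (default) stream.
        copy_stream.wait_stream(torch.cuda.current_stream())
        dst_cpu.copy_(src_gpu, non_blocking=True)
        copy_event.record()
    return dst_cpu, copy_event

if torch.cuda.is_available():
    cpu_vals, event = async_d2h_copy(torch.arange(8, device="cuda"))
    # ... overlap other GPU/CPU work here ...
    event.synchronize()  # block only at the point of use
    print(cpu_vals.tolist())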
        del self._logprobs_tensors
         del self._sampled_token_ids
-
-        valid_sampled_token_ids: list[np.ndarray] = [
-            row for row in self.sampled_token_ids_cpu.numpy()
-        ]
+        max_gen_len = self.sampled_token_ids_cpu.shape[-1]
+        if max_gen_len == 1:
+            valid_sampled_token_ids: list[np.ndarray] = [
+                row for row in self.sampled_token_ids_cpu.numpy()
+            ]
+        else:
+            valid_sampled_token_ids = RejectionSampler.parse_output(
+                self.sampled_token_ids_cpu,
+                self.vocab_size,
+            )
 
         for i in self._invalid_req_indices:
             valid_sampled_token_ids[i] = np.array([])
@@ -377,6 +385,10 @@ class GPUModelRunner(
         )
         self.rejection_sampler = RejectionSampler(self.sampler)
 
+        self.num_spec_tokens = 0
+        if self.speculative_config:
+            self.num_spec_tokens = self.speculative_config.num_speculative_tokens
+
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
         self.comm_stream = torch.cuda.Stream()
@@ -513,11 +525,7 @@ class GPUModelRunner(
             self.max_num_tokens, dtype=torch.int32, device=self.device
         )
 
-        self.uniform_decode_query_len = (
-            1
-            if not self.speculative_config
-            else 1 + self.speculative_config.num_speculative_tokens
-        )
+        self.uniform_decode_query_len = 1 + self.num_spec_tokens
 
         # Cudagraph dispatcher for runtime cudagraph dispatching.
         self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
@@ -549,6 +557,20 @@ class GPUModelRunner(
             pin_memory=self.pin_memory,
         )
 
+        # Pre-allocated tensor for copying valid sampled token counts to CPU,
+        # with a dedicated stream for overlapping and an event for coordination.
+        self.valid_sampled_token_count_event: torch.cuda.Event | None = None
+        self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
+        if self.use_async_scheduling and self.num_spec_tokens:
+            self.valid_sampled_token_count_event = torch.cuda.Event()
+            self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
+            self.valid_sampled_token_count_cpu = torch.empty(
+                self.max_num_reqs,
+                dtype=torch.int64,
+                device="cpu",
+                pin_memory=self.pin_memory,
+            )
+
         # Ephemeral state transferred between execute_model() and sample_tokens().
         self.execute_model_state: ExecuteModelState | None = None
@@ -736,17 +758,45 @@ class GPUModelRunner(
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
+
+        # Wait until valid_sampled_tokens_count has been copied to the CPU,
+        # then use it to update the actual num_computed_tokens of each request.
+        valid_sampled_token_count = self._get_valid_sampled_token_count()
+
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_id in req_data.resumed_req_ids
             num_output_tokens = req_data.num_output_tokens[i]
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            # prev_num_draft_len is used when async scheduling and spec
+            # decode are enabled together. It indicates whether we need to
+            # update num_computed_tokens of the request. For example:
+            # first step: num_computed_tokens = 0, spec_tokens = [],
+            # prev_num_draft_len = 0.
+            # second step: num_computed_tokens = 100 (prompt length),
+            # spec_tokens = [a, b], prev_num_draft_len = 0.
+            # third step: num_computed_tokens = 100 + 2, spec_tokens = [c, d],
+            # prev_num_draft_len = 2.
+            # num_computed_tokens in the first and second steps doesn't
+            # include the spec token count, but in the third step it does.
+            # We therefore only need to adjust num_computed_tokens when
+            # prev_num_draft_len > 0.
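To make the bookkeeping described in the comment above concrete, a small pure-Python sketch with made-up numbers (2 draft tokens scheduled, 1 of them accepted), mirroring the accept/reject arithmetic used just below:

# Hypothetical numbers for illustration only.
prev_num_draft_len = 2          # draft tokens scheduled in the previous step
valid_sampled_token_count = 2   # tokens the sampler actually emitted
                                # (1 "bonus" token + accepted draft tokens)

num_accepted = valid_sampled_token_count - 1
num_rejected = prev_num_draft_len - num_accepted

# The scheduler optimistically counted all draft tokens as computed, so the
# worker walks num_computed_tokens back by the rejected ones and appends
# placeholder output ids for the accepted ones.
num_computed_tokens = 102        # e.g. 100 prompt tokens + 2 draft tokens
num_computed_tokens -= num_rejected
output_token_ids = [-1] * num_accepted

assert (num_accepted, num_rejected, num_computed_tokens) == (1, 1, 101)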
+            if req_state.prev_num_draft_len:
+                if req_index is None:
+                    req_state.prev_num_draft_len = 0
+                else:
+                    assert self.input_batch.prev_req_id_to_index is not None
+                    prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
+                    num_accepted = valid_sampled_token_count[prev_req_index] - 1
+                    num_rejected = req_state.prev_num_draft_len - num_accepted
+                    num_computed_tokens -= num_rejected
+                    req_state.output_token_ids.extend([-1] * num_accepted)
 
             # Update the cached states.
-            req_state.num_computed_tokens = num_computed_tokens
 
-            req_index = self.input_batch.req_id_to_index.get(req_id)
             if not is_last_rank:
                 # When using PP, the scheduler sends the sampled tokens back,
@@ -823,8 +873,11 @@ class GPUModelRunner(
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, []
             )
-            if spec_token_ids:
-                num_spec_tokens = len(spec_token_ids)
+            num_spec_tokens = len(spec_token_ids)
+            # For async scheduling, token_ids_cpu assigned from
+            # spec_token_ids are placeholders and will be overwritten in
+            # _prepare_input_ids.
+            if num_spec_tokens:
                 start_index = self.input_batch.num_tokens_no_spec[req_index]
                 end_token_index = start_index + num_spec_tokens
                 self.input_batch.token_ids_cpu[
@@ -840,6 +893,15 @@ class GPUModelRunner(
                 # even when speculative decoding is enabled.
                 self.input_batch.spec_token_ids[req_index] = spec_token_ids
 
+            # With async scheduling there may be no draft tokens yet; then we
+            # clear the spec decoding info in scheduler_output and fall back
+            # to normal sampling instead of rejection sampling.
+            if self.use_async_scheduling:
+                req_state.prev_num_draft_len = num_spec_tokens
+                if num_spec_tokens and self._draft_token_ids is None:
+                    scheduler_output.total_num_scheduled_tokens -= num_spec_tokens
+                    scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens
+                    scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None)
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
         for request in reqs_to_add:
@@ -959,7 +1021,10 @@ class GPUModelRunner(
         return cu_num_tokens, arange
 
     def _prepare_input_ids(
-        self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray
+        self,
+        scheduler_output: "SchedulerOutput",
+        total_num_scheduled_tokens: int,
+        cu_num_tokens: np.ndarray,
     ) -> None:
         """Prepare the input IDs for the current batch.
@@ -980,21 +1045,43 @@ class GPUModelRunner(
         # on the GPU from prev_sampled_token_ids.
         prev_req_id_to_index = self.input_batch.prev_req_id_to_index
         assert prev_req_id_to_index is not None
-        flattened_indices = []
-        prev_common_req_indices = []
+        sample_flattened_indices: list[int] = []
+        spec_flattened_indices: list[int] = []
+        prev_common_req_indices: list[int] = []
+        prev_draft_token_indices: list[int] = []
         indices_match = True
         max_flattened_index = -1
+        total_num_spec_tokens = 0
+        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
+
         for req_id, cur_index in self.input_batch.req_id_to_index.items():
             if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
                 prev_common_req_indices.append(prev_index)
                 # We need to compute the flattened input_ids index of the
                 # last token in each common request.
+ draft_len = len(scheduled_spec_tokens.get(req_id, ())) + total_num_spec_tokens += draft_len flattened_index = cu_num_tokens[cur_index].item() - 1 - flattened_indices.append(flattened_index) + # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] + # sample_flattened_indices = [0, 2, 5] + # spec_flattened_indices = [1, 3, 4, 6, 7] + sample_flattened_indices.append(flattened_index - draft_len) + spec_flattened_indices.extend( + range(flattened_index - draft_len + 1, flattened_index + 1) + ) + start = prev_index * self.num_spec_tokens + # prev_draft_token_indices is used to find which draft_tokens_id + # should be copied to input_ids + # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] + # flatten draft_tokens_id [1,2,3,4,5,6] + # draft_len of each request [1, 2, 1] + # then prev_draft_token_indices is [0, 2, 3, 4] + prev_draft_token_indices.extend(range(start, start + draft_len)) indices_match &= prev_index == flattened_index max_flattened_index = max(max_flattened_index, flattened_index) - num_commmon_tokens = len(flattened_indices) - if num_commmon_tokens < total_num_scheduled_tokens: + num_commmon_tokens = len(sample_flattened_indices) + total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens + if num_commmon_tokens < total_without_spec: # If not all requests are decodes from the last iteration, # We need to copy the input_ids_cpu to the GPU first. self.input_ids.copy_to_gpu(total_num_scheduled_tokens) @@ -1018,20 +1105,43 @@ class GPUModelRunner( self.is_token_ids.gpu[:num_commmon_tokens] = True return # Upload the index tensors asynchronously so the scatter can be non-blocking. - input_ids_index_tensor = torch.tensor( - flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + sampled_tokens_index_tensor = torch.tensor( + sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, - index=input_ids_index_tensor, + index=sampled_tokens_index_tensor, src=self.input_batch.prev_sampled_token_ids[ prev_common_req_indices_tensor, 0 ], ) + # Scatter the draft tokens after the sampled tokens are scattered. + if self._draft_token_ids is None or not spec_flattened_indices: + return + + assert isinstance(self._draft_token_ids, torch.Tensor) + draft_tokens_index_tensor = torch.tensor( + spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + prev_draft_token_indices_tensor = torch.tensor( + prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + + # because input_ids dtype is torch.int32, + # so convert draft_token_ids to torch.int32 here. + draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) + self._draft_token_ids = None + + self.input_ids.gpu.scatter_( + dim=0, + index=draft_tokens_index_tensor, + src=draft_token_ids.flatten()[prev_draft_token_indices_tensor], + ) + def _get_encoder_seq_lens( self, scheduled_encoder_inputs: dict[str, list[int]], @@ -1218,7 +1328,11 @@ class GPUModelRunner( self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) # Copy the tensors to the GPU. 
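A standalone sketch reproducing the index arithmetic from the in-line examples above, using cu_num_tokens = [2, 5, 8], per-request draft lengths [1, 2, 2], and num_spec_tokens = 2 (illustrative values only):

cu_num_tokens = [2, 5, 8]      # cumulative scheduled tokens per request
draft_lens = [1, 2, 2]         # scheduled draft tokens per request
num_spec_tokens = 2            # max drafts per request (row stride)

sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = []
prev_draft_token_indices: list[int] = []

for prev_index, (cu, draft_len) in enumerate(zip(cu_num_tokens, draft_lens)):
    flattened_index = cu - 1
    # Position of the previously sampled token for this request.
    sample_flattened_indices.append(flattened_index - draft_len)
    # Positions of this request's draft tokens in the flattened input_ids.
    spec_flattened_indices.extend(
        range(flattened_index - draft_len + 1, flattened_index + 1)
    )
    # Offsets into the flattened [num_reqs, num_spec_tokens] draft buffer.
    start = prev_index * num_spec_tokens
    prev_draft_token_indices.extend(range(start, start + draft_len))

assert sample_flattened_indices == [0, 2, 5]
assert spec_flattened_indices == [1, 3, 4, 6, 7]
assert prev_draft_token_indices == [0, 2, 3, 4, 5]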
-        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
+        self._prepare_input_ids(
+            scheduler_output,
+            total_num_scheduled_tokens,
+            cu_num_tokens,
+        )
 
         if self.uses_mrope:
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -2377,12 +2491,14 @@ class GPUModelRunner(
             valid_sampled_token_ids = []
             invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
             invalid_req_indices_set = set(invalid_req_indices)
-            assert sampled_token_ids.shape[-1] == 1
 
             # Cache the sampled tokens on the GPU and avoid CPU sync.
             # These will be copied into input_ids in the next step
             # when preparing inputs.
-            self.input_batch.prev_sampled_token_ids = sampled_token_ids
+            # With spec decoding, this is done in propose_draft_token_ids().
+            if self.input_batch.prev_sampled_token_ids is None:
+                assert sampled_token_ids.shape[-1] == 1
+                self.input_batch.prev_sampled_token_ids = sampled_token_ids
             self.input_batch.prev_req_id_to_index = {
                 req_id: i for i, req_id in enumerate(self.input_batch.req_ids)
@@ -2517,6 +2633,21 @@ class GPUModelRunner(
                 "State error: sample_tokens() must be called "
                 "after execute_model() returns None."
             )
+
+        # self._draft_token_ids is None when `input_fits_in_drafter=False`
+        # and no draft tokens were scheduled, so with async scheduling we
+        # need to update the spec decoding info in scheduler_output here.
+        # Use deepcopy so that the modification doesn't affect the
+        # scheduler_output held by the engine core.
+        # TODO(Ronald1995): deepcopy is expensive when there is a large
+        # number of requests, optimize it later.
+        if (
+            self.use_async_scheduling
+            and self.num_spec_tokens
+            and self._draft_token_ids is None
+        ):
+            scheduler_output = deepcopy(scheduler_output)
+
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         with record_function_or_nullcontext("gpu_model_runner: preprocess"):
             with self.synchronize_input_prep():
@@ -2759,6 +2890,8 @@ class GPUModelRunner(
         with record_function_or_nullcontext("gpu_model_runner: sample"):
             sampler_output = self._sample(logits, spec_decode_metadata)
 
+        self.input_batch.prev_sampled_token_ids = None
+
         def propose_draft_token_ids(
             sampled_token_ids: torch.Tensor | list[np.ndarray],
         ) -> None:
@@ -2792,14 +2925,29 @@ class GPUModelRunner(
                 self.speculative_config.draft_model_config.max_model_len
             )
             input_fits_in_drafter = spec_decode_common_attn_metadata and (
-                spec_decode_common_attn_metadata.max_seq_len
-                + self.speculative_config.num_speculative_tokens
+                spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
                 <= effective_drafter_max_model_len
            )
-            if use_padded_batch_for_eagle and input_fits_in_drafter:
-                # EAGLE speculative decoding can use the GPU sampled tokens
-                # as inputs, and does not need to wait for bookkeeping to finish.
-                propose_draft_token_ids(sampler_output.sampled_token_ids)
+            if use_padded_batch_for_eagle:
+                sampled_token_ids = sampler_output.sampled_token_ids
+                if input_fits_in_drafter:
+                    # EAGLE speculative decoding can use the GPU sampled tokens
+                    # as inputs, and does not need to wait for bookkeeping to finish.
+ propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( @@ -2856,6 +3004,7 @@ class GPUModelRunner( logprobs_tensors=sampler_output.logprobs_tensors, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, ) with record_function_or_nullcontext( "gpu_model_runner: set_async_sampled_token_ids" @@ -2880,6 +3029,37 @@ class GPUModelRunner( self._draft_token_ids = None return DraftTokenIds(req_ids, draft_token_ids) + def _copy_valid_sampled_token_count( + self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor + ) -> None: + if self.valid_sampled_token_count_event is None: + return + + default_stream = torch.cuda.current_stream() + # Initialize a new stream to overlap the copy operation with + # prepare_input of draft model. + with torch.cuda.stream(self.valid_sampled_token_count_copy_stream): + self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore + counts = valid_sampled_tokens_count + counts_cpu = self.valid_sampled_token_count_cpu + counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) + self.valid_sampled_token_count_event.record() + + self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1) + + def _get_valid_sampled_token_count(self) -> list[int]: + # Wait until valid_sampled_tokens_count is copied to cpu, + prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids + if ( + self.valid_sampled_token_count_event is None + or prev_sampled_token_ids is None + ): + return [] + + counts_cpu = self.valid_sampled_token_count_cpu + self.valid_sampled_token_count_event.synchronize() + return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist() + def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", @@ -2967,6 +3147,9 @@ class GPUModelRunner( self.num_discarded_requests, ) ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) if spec_decode_metadata is None: token_indices_to_sample = None @@ -3532,7 +3715,7 @@ class GPUModelRunner( # TODO(luka) better system for describing dummy batches seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] else: - seq_lens = max_query_len + seq_lens = max_query_len # type: ignore[assignment] self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() @@ -4485,11 +4668,7 @@ class GPUModelRunner( logitsprocs=self.input_batch.logitsprocs, logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, - num_speculative_tokens=( - self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config - else 0 - ), + num_speculative_tokens=self.num_spec_tokens, ) def _allocate_kv_cache_tensors(
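For context, a minimal offline usage sketch of the combination this patch enables, mirroring the e2e test configuration above. The model and speculator names are taken from the test file; passing async_scheduling and speculative_config straight to LLM() is assumed to behave the same way as through the test's VllmRunner:

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    async_scheduling=True,
    speculative_config={
        "method": "eagle3",
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
        "num_speculative_tokens": 2,
        # disable_padded_drafter_batch must remain False; VllmConfig now
        # rejects disable_padded_drafter_batch=True when async scheduling
        # is enabled.
    },
)
outputs = llm.generate(
    ["In one word, the capital of France is "],
    SamplingParams(temperature=0.0, max_tokens=23),
)
print(outputs[0].outputs[0].text)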