[Core] Async Scheduling X Spec Decoding Compatibility (#24799)

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
Ronald 2025-11-18 04:16:20 +08:00 committed by GitHub
parent f8b19c0ffd
commit d8874c61a5
11 changed files with 314 additions and 98 deletions

View File

@ -15,7 +15,7 @@ from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal from ...models.utils import check_outputs_equal
MODEL = "Qwen/Qwen3-0.6B" MODEL = "Qwen/Qwen3-0.6B"
MTP_MODEL = "XiaomiMiMo/MiMo-7B-Base" MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
first_prompt = ( first_prompt = (
@ -29,7 +29,8 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [
default_params = dict( default_params = dict(
temperature=0.0, # greedy temperature=0.0, # greedy
max_tokens=20, max_tokens=23,
min_tokens=18,
) )
@ -69,15 +70,9 @@ def test_without_spec_decoding(
(True, "uni", True, None, True), (True, "uni", True, None, True),
] ]
run_tests( run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
monkeypatch,
MODEL,
test_configs,
test_sampling_params,
)
@pytest.mark.skip("MTP model too big to run in fp32 in CI")
def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of """Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking, preemption, executor, async scheduling, prefill chunking,
@ -85,8 +80,9 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
""" """
spec_config = { spec_config = {
"method": "mtp", "method": "eagle3",
"num_speculative_tokens": 2, "num_speculative_tokens": 2,
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
} }
spec_config_short = spec_config | {"max_model_len": 50} spec_config_short = spec_config | {"max_model_len": 50}
@ -106,12 +102,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True), (True, "uni", True, spec_config_short, True),
] ]
run_tests( run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
monkeypatch,
MTP_MODEL,
test_configs,
[{}],
)
@dynamo_config.patch(cache_size_limit=16) @dynamo_config.patch(cache_size_limit=16)
@ -182,15 +173,13 @@ def run_tests(
and test_acceptance_rate is not None and test_acceptance_rate is not None
): ):
if "spec_mml=None" in test_config: if "spec_mml=None" in test_config:
# because the acceptance rate can vary, we use a looser
# tolerance here.
assert ( assert (
pytest.approx(test_acceptance_rate, rel=5e-2) pytest.approx(test_acceptance_rate, rel=5e-2)
== base_acceptance_rate == base_acceptance_rate
) )
else: else:
# Currently the reported acceptance rate is expected to be # Currently the reported acceptance rate is expected to be
# lower when we skip drafting altogether. # lower when we sometimes skip drafting altogether.
assert test_acceptance_rate > 0.05 assert test_acceptance_rate > 0.05
print( print(
f"PASSED: config=[{test_config}], params={params}" f"PASSED: config=[{test_config}], params={params}"
@ -220,6 +209,7 @@ def run_test(
): ):
spec_decoding = spec_config is not None spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = ( cache_arg: dict[str, Any] = (
# Force preemptions
dict(num_gpu_blocks_override=32) dict(num_gpu_blocks_override=32)
if test_preemption if test_preemption
else dict(gpu_memory_utilization=0.9) else dict(gpu_memory_utilization=0.9)
@ -238,6 +228,7 @@ def run_test(
model, model,
max_model_len=512, max_model_len=512,
enable_chunked_prefill=test_prefill_chunking, enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None, max_num_batched_tokens=48 if test_prefill_chunking else None,
# enforce_eager=True, # enforce_eager=True,
async_scheduling=async_scheduling, async_scheduling=async_scheduling,
@ -255,10 +246,7 @@ def run_test(
results.append( results.append(
vllm_model.generate( vllm_model.generate(
example_prompts, example_prompts,
sampling_params=SamplingParams( sampling_params=SamplingParams(**default_params, **override_params),
**default_params,
**override_params,
),
return_logprobs=True, return_logprobs=True,
) )
) )
@ -270,9 +258,7 @@ def run_test(
if test_preemption: if test_preemption:
preemptions = _get_count( preemptions = _get_count(
metrics_before, metrics_before, metrics_after, "vllm:num_preemptions"
metrics_after,
"vllm:num_preemptions",
) )
assert preemptions > 0, "preemption test had no preemptions" assert preemptions > 0, "preemption test had no preemptions"

View File

@ -3,7 +3,7 @@
import ast import ast
import hashlib import hashlib
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
@ -29,31 +29,25 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
SpeculativeMethod = Literal[ MTPModelTypes = Literal[
"ngram",
"eagle",
"eagle3",
"medusa",
"mlp_speculator",
"draft_model",
"deepseek_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"mimo_mtp",
"longcat_flash_mtp",
"pangu_ultra_moe_mtp",
"mtp",
"suffix",
]
MTP_MODEL_TYPES = (
"deepseek_mtp", "deepseek_mtp",
"mimo_mtp", "mimo_mtp",
"glm4_moe_mtp", "glm4_moe_mtp",
"ernie_mtp", "ernie_mtp",
"qwen3_next_mtp", "qwen3_next_mtp",
"longcat_flash_mtp", "longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp", "pangu_ultra_moe_mtp",
) ]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
"mlp_speculator",
"draft_model",
"suffix",
EagleModelTypes,
]
@config @config
@ -244,7 +238,7 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by # can not be detected, it will be considered as the "draft_model" by
# default. # default.
if self.method in MTP_MODEL_TYPES: if self.method in get_args(MTPModelTypes) and self.method != "mtp":
logger.warning( logger.warning(
"method `%s` is deprecated and replaced with mtp.", self.method "method `%s` is deprecated and replaced with mtp.", self.method
) )
@ -361,7 +355,9 @@ class SpeculativeConfig:
self.method = "medusa" self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator": elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
self.method = "mlp_speculator" self.method = "mlp_speculator"
elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES: elif self.draft_model_config.hf_config.model_type in get_args(
MTPModelTypes
):
self.method = "mtp" self.method = "mtp"
if self.num_speculative_tokens > 1: if self.num_speculative_tokens > 1:
logger.warning( logger.warning(

View File

@ -14,13 +14,14 @@ from dataclasses import replace
from datetime import datetime from datetime import datetime
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar from typing import TYPE_CHECKING, Any, TypeVar, get_args
import torch import torch
from pydantic import ConfigDict, Field, model_validator from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid from vllm.utils import random_uuid
@ -374,10 +375,22 @@ class VllmConfig:
"Async scheduling is not yet compatible with " "Async scheduling is not yet compatible with "
"pipeline_parallel_size > 1." "pipeline_parallel_size > 1."
) )
# Currently, async scheduling only supports EAGLE-style speculative
# decoding.
if self.speculative_config is not None: if self.speculative_config is not None:
raise ValueError( if self.speculative_config.method not in get_args(EagleModelTypes):
"Async scheduling is not yet compatible with speculative decoding." raise ValueError(
) "Currently, async scheduling is only supported "
"with EAGLE/MTP kind of speculative decoding"
)
if self.speculative_config.disable_padded_drafter_batch:
raise ValueError(
"async scheduling for EAGLE/MTP kind of speculative "
"decoding is enabled, but disable_padded_drafter_batch=True "
"disable_padded_drafter_batch=True is not supported for "
"this situation now. please set "
"disable_padded_drafter_batch=Fasle"
)
if not executor_supports_async_sched: if not executor_supports_async_sched:
raise ValueError( raise ValueError(
"Currently, async scheduling only supports `mp`, `uni`, or " "Currently, async scheduling only supports `mp`, `uni`, or "

View File

@ -16,18 +16,25 @@ class AsyncScheduler(Scheduler):
) -> None: ) -> None:
super()._update_after_schedule(scheduler_output) super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens: for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id] request = self.requests[req_id]
pending_structured_output_tokens |= ( pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0 request.use_structured_output and request.num_output_placeholders > 0
) )
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
if ( if (
request.num_computed_tokens request.num_computed_tokens
== request.num_tokens + request.num_output_placeholders == request.num_tokens
+ request.num_output_placeholders
+ cur_num_spec_tokens
): ):
# The request will generate a new token in this scheduling step. # The request will generate a new token plus num_spec_tokens
# TODO(woosuk): Support speculative decoding. # in this scheduling step.
request.num_output_placeholders += 1 request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new tokens in spec_token_ids.
# We will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens
scheduler_output.pending_structured_output_tokens = ( scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens pending_structured_output_tokens
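
The bookkeeping above boils down to: with async scheduling the scheduler never sees the tokens sampled in the current step, so it reserves 1 + num_spec_tokens output placeholders and fills spec_token_ids with dummy ids that the worker overwrites later. A rough standalone sketch of that accounting, using a simplified request object rather than the real Scheduler/Request API:

# Hypothetical, simplified request state used only to illustrate the
# placeholder accounting in AsyncScheduler._update_after_schedule.
from dataclasses import dataclass, field

@dataclass
class Req:
    num_tokens: int                  # prompt + committed output tokens
    num_computed_tokens: int         # tokens already processed by the model
    num_output_placeholders: int = 0
    spec_token_ids: list[int] = field(default_factory=list)

def after_schedule(req: Req, num_scheduled_spec: int, num_spec_tokens: int) -> None:
    # If everything scheduled this step (including the speculative tokens)
    # will be computed, the request yields one sampled token plus up to
    # num_scheduled_spec accepted drafts; reserve placeholders for them.
    if req.num_computed_tokens == (
        req.num_tokens + req.num_output_placeholders + num_scheduled_spec
    ):
        req.num_output_placeholders += 1 + num_scheduled_spec
        # Dummy draft ids; the worker process fills in the real ones later.
        req.spec_token_ids = [-1] * num_spec_tokens

req = Req(num_tokens=100, num_computed_tokens=102)
after_schedule(req, num_scheduled_spec=2, num_spec_tokens=2)
assert req.num_output_placeholders == 3 and req.spec_token_ids == [-1, -1]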

View File

@ -348,7 +348,10 @@ class Scheduler(SchedulerInterface):
# Speculative decode related. # Speculative decode related.
if request.spec_token_ids: if request.spec_token_ids:
num_scheduled_spec_tokens = ( num_scheduled_spec_tokens = (
num_new_tokens + request.num_computed_tokens - request.num_tokens num_new_tokens
+ request.num_computed_tokens
- request.num_tokens
- request.num_output_placeholders
) )
if num_scheduled_spec_tokens > 0: if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens. # Trim spec_token_ids list to num_scheduled_spec_tokens.
@ -1024,7 +1027,12 @@ class Scheduler(SchedulerInterface):
# tokens and rejections. If some tokens are rejected, # tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected # num_computed_tokens is decreased by the number of rejected
# tokens. # tokens.
request.num_computed_tokens -= num_rejected if request.num_computed_tokens > 0:
request.num_computed_tokens -= num_rejected
# If async scheduling, num_output_placeholders also includes
# the scheduled spec tokens count and so is similarly adjusted.
if request.num_output_placeholders > 0:
request.num_output_placeholders -= num_rejected
spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=num_draft_tokens, num_draft_tokens=num_draft_tokens,
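
To make the rollback above concrete: if two draft tokens were scheduled for a request and only one was accepted, both num_computed_tokens and (under async scheduling) num_output_placeholders were advanced past the rejected token and must be walked back by the number of rejections. A small illustrative sketch of that arithmetic, not the real Scheduler code:

# Illustrative only: how rejected draft tokens roll back the optimistic
# counters kept by the scheduler (names mirror the diff, not the real classes).
def roll_back_rejections(
    num_computed_tokens: int,
    num_output_placeholders: int,
    num_draft_tokens: int,
    num_accepted: int,
) -> tuple[int, int]:
    num_rejected = num_draft_tokens - num_accepted
    if num_computed_tokens > 0:
        num_computed_tokens -= num_rejected
    # With async scheduling, the placeholders also covered the draft tokens,
    # so they are walked back by the same amount.
    if num_output_placeholders > 0:
        num_output_placeholders -= num_rejected
    return num_computed_tokens, num_output_placeholders

# Two drafts scheduled, one accepted -> one token of progress is undone.
assert roll_back_rejections(103, 3, num_draft_tokens=2, num_accepted=1) == (102, 2)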

View File

@ -198,6 +198,7 @@ class EngineCore:
self.step_fn = ( self.step_fn = (
self.step if self.batch_queue is None else self.step_with_batch_queue self.step if self.batch_queue is None else self.step_with_batch_queue
) )
self.async_scheduling = vllm_config.scheduler_config.async_scheduling
# Mark the startup heap as static so that it's ignored by GC. # Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections. # Reduces pause times of oldest generation collections.
@ -341,7 +342,10 @@ class EngineCore:
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
def post_step(self, model_executed: bool) -> None: def post_step(self, model_executed: bool) -> None:
if self.use_spec_decode and model_executed: # When using async scheduling we can't get draft token ids in advance,
# so we update draft token ids in the worker process and don't
# need to update draft token ids here.
if not self.async_scheduling and self.use_spec_decode and model_executed:
# Take the draft token ids. # Take the draft token ids.
draft_token_ids = self.model_executor.take_draft_token_ids() draft_token_ids = self.model_executor.take_draft_token_ids()
if draft_token_ids is not None: if draft_token_ids is not None:

View File

@ -150,6 +150,23 @@ class Processor:
raise ValueError( raise ValueError(
"vLLM V1 does not support per request user provided logits processors." "vLLM V1 does not support per request user provided logits processors."
) )
# Async scheduling + spec decode currently incompatible with some
# sampling parameters.
if (
self.vllm_config.speculative_config is not None
and self.vllm_config.scheduler_config.async_scheduling
and (
params.frequency_penalty != 0.0
or params.presence_penalty != 0.0
or params.repetition_penalty != 1.0
or params.bad_words_token_ids
or params.structured_outputs
)
):
raise ValueError(
"async scheduling with spec decoding doesn't yet support "
"penalties, bad words or structured outputs in sampling parameters."
)
def _validate_params( def _validate_params(
self, self,
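
From the client side, the effect of this check is that a request combining async scheduling and spec decoding with penalties, bad words, or structured outputs is rejected when it is submitted. A hedged usage sketch (the model and speculative settings below are taken from the test file in this diff and are placeholders, not requirements):

# Hypothetical usage sketch: with async scheduling and spec decoding enabled,
# penalty-style sampling parameters are rejected when the request is added.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    speculative_config={
        "method": "eagle3",
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
        "num_speculative_tokens": 2,
    },
    async_scheduling=True,
)

ok = SamplingParams(temperature=0.0, max_tokens=32)            # accepted
bad = SamplingParams(temperature=0.0, repetition_penalty=1.2)  # rejected

llm.generate(["Hello, my name is"], ok)
try:
    llm.generate(["Hello, my name is"], bad)
except ValueError as e:
    print(e)  # "async scheduling with spec decoding doesn't yet support ..."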

View File

@ -41,7 +41,7 @@ STR_POOLING_REJECTS_LOGITSPROCS = (
# Error message when the user tries to initialize vLLM with a speculative # Error message when the user tries to initialize vLLM with a speculative
# decoding enabled and custom logitsproces # decoding enabled and custom logitsproces
STR_SPEC_DEC_REJECTS_LOGITSPROCS = ( STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
"Custom logits processors are not supportedwhen speculative decoding is enabled." "Custom logits processors are not supported when speculative decoding is enabled."
) )
LOGITSPROCS_GROUP = "vllm.logits_processors" LOGITSPROCS_GROUP = "vllm.logits_processors"

View File

@ -397,10 +397,13 @@ class EagleProposer:
positions += 1 positions += 1
exceeds_max_model_len = positions >= self.max_model_len exceeds_max_model_len = positions >= self.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions) clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
# For data integrity with async scheduling, we shouldn't use in-place
# operations here, in case the tensors are modified by the next step's
# `prepare_input` of the main model.
# Increment the sequence lengths. # Increment the sequence lengths.
common_attn_metadata.seq_lens += 1 common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1 # This is an out-of-place operation to avoid modifying the original tensor.
common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
# For the requests that exceed the max model length, we set the # For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention. # sequence length to 1 to minimize their overheads in attention.
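
The point of the comment above is the aliasing hazard: t += 1 mutates the tensor that the previous step's metadata may still reference, while t = t + 1 allocates a fresh tensor and leaves the shared one untouched. A tiny illustration, independent of vLLM:

import torch

shared = torch.tensor([10, 20, 30])   # e.g. seq_lens_cpu, still referenced elsewhere
alias = shared                        # the "other step" holding the same tensor

# In-place: the aliased tensor changes too, which is the hazard when the
# next step's prepare_input may still be reading it.
shared += 1
assert alias.tolist() == [11, 21, 31]

# Out-of-place: a new tensor is created; the aliased tensor is untouched.
shared = shared + 1
assert alias.tolist() == [11, 21, 31]
assert shared.tolist() == [12, 22, 32]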

View File

@ -46,6 +46,9 @@ class CachedRequestState:
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None prompt_embeds: torch.Tensor | None = None
# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0
def __post_init__(self): def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds self.prompt_token_ids, self.prompt_embeds

View File

@ -179,6 +179,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
logprobs_tensors: torch.Tensor | None, logprobs_tensors: torch.Tensor | None,
invalid_req_indices: list[int], invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream, async_output_copy_stream: torch.cuda.Stream,
vocab_size: int,
): ):
self._model_runner_output = model_runner_output self._model_runner_output = model_runner_output
self._invalid_req_indices = invalid_req_indices self._invalid_req_indices = invalid_req_indices
@ -189,6 +190,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
# Keep a reference to the device tensor to avoid it being # Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host. # deallocated until we finish copying it to the host.
self._sampled_token_ids = sampled_token_ids self._sampled_token_ids = sampled_token_ids
self.vocab_size = vocab_size
self._logprobs_tensors = logprobs_tensors self._logprobs_tensors = logprobs_tensors
# Initiate the copy on a separate stream, but do not synchronize it. # Initiate the copy on a separate stream, but do not synchronize it.
@ -215,10 +217,16 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
# Release the device tensors once the copy has completed. # Release the device tensors once the copy has completed.
del self._logprobs_tensors del self._logprobs_tensors
del self._sampled_token_ids del self._sampled_token_ids
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
valid_sampled_token_ids: list[np.ndarray] = [ if max_gen_len == 1:
row for row in self.sampled_token_ids_cpu.numpy() valid_sampled_token_ids: list[np.ndarray] = [
] row for row in self.sampled_token_ids_cpu.numpy()
]
else:
valid_sampled_token_ids = RejectionSampler.parse_output(
self.sampled_token_ids_cpu,
self.vocab_size,
)
for i in self._invalid_req_indices: for i in self._invalid_req_indices:
valid_sampled_token_ids[i] = np.array([]) valid_sampled_token_ids[i] = np.array([])
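
When rows can contain more than one token (spec decoding), the simple per-row split is replaced by RejectionSampler.parse_output. Roughly, it keeps only the valid entries of each row and drops slots that were rejected or never filled, which are marked with a placeholder or out-of-range id. A hedged standalone approximation of that parsing, not the real implementation:

import numpy as np
import torch

PLACEHOLDER_TOKEN_ID = -1  # assumed sentinel for "no token in this slot"

def parse_output_like(sampled_token_ids: torch.Tensor, vocab_size: int) -> list[np.ndarray]:
    # Keep, per request row, only entries that hold a real token id
    # (not a placeholder and within the vocabulary).
    ids = sampled_token_ids.numpy()
    valid = (ids != PLACEHOLDER_TOKEN_ID) & (ids < vocab_size)
    return [row[valid[i]] for i, row in enumerate(ids)]

# Row 0: bonus token only; row 1: two accepted drafts plus the bonus token.
out = torch.tensor([[7, -1, -1], [3, 9, 12]])
parsed = parse_output_like(out, vocab_size=32000)
assert parsed[0].tolist() == [7] and parsed[1].tolist() == [3, 9, 12]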
@ -377,6 +385,10 @@ class GPUModelRunner(
) )
self.rejection_sampler = RejectionSampler(self.sampler) self.rejection_sampler = RejectionSampler(self.sampler)
self.num_spec_tokens = 0
if self.speculative_config:
self.num_spec_tokens = self.speculative_config.num_speculative_tokens
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
self.comm_stream = torch.cuda.Stream() self.comm_stream = torch.cuda.Stream()
@ -513,11 +525,7 @@ class GPUModelRunner(
self.max_num_tokens, dtype=torch.int32, device=self.device self.max_num_tokens, dtype=torch.int32, device=self.device
) )
self.uniform_decode_query_len = ( self.uniform_decode_query_len = 1 + self.num_spec_tokens
1
if not self.speculative_config
else 1 + self.speculative_config.num_speculative_tokens
)
# Cudagraph dispatcher for runtime cudagraph dispatching. # Cudagraph dispatcher for runtime cudagraph dispatching.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
@ -549,6 +557,20 @@ class GPUModelRunner(
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
# Pre-allocated tensor for copying valid sampled token counts to CPU,
# with dedicated stream for overlapping and event for coordination.
self.valid_sampled_token_count_event: torch.cuda.Event | None = None
self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
if self.use_async_scheduling and self.num_spec_tokens:
self.valid_sampled_token_count_event = torch.cuda.Event()
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
self.valid_sampled_token_count_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
# Ephemeral state transferred between execute_model() and sample_tokens(). # Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None self.execute_model_state: ExecuteModelState | None = None
@ -736,17 +758,45 @@ class GPUModelRunner(
# Update the states of the running/resumed requests. # Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs req_data = scheduler_output.scheduled_cached_reqs
# Wait until valid_sampled_tokens_count has been copied to the CPU,
# then use it to update the actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count()
for i, req_id in enumerate(req_data.req_ids): for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i] num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i] new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_id in req_data.resumed_req_ids resumed_from_preemption = req_id in req_data.resumed_req_ids
num_output_tokens = req_data.num_output_tokens[i] num_output_tokens = req_data.num_output_tokens[i]
req_index = self.input_batch.req_id_to_index.get(req_id)
# prev_num_draft_len is used in async scheduling mode with
# spec decode. It indicates whether num_computed_tokens of the
# request needs to be adjusted. For example:
# first step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0.
# second step: num_computed_tokens = 100 (prompt length),
# spec_tokens = [a, b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c, d],
# prev_num_draft_len = 2.
# num_computed_tokens in the first and second steps doesn't include
# the spec token count, but in the third step it does, so we only
# need to adjust num_computed_tokens when prev_num_draft_len > 0.
if req_state.prev_num_draft_len:
if req_index is None:
req_state.prev_num_draft_len = 0
else:
assert self.input_batch.prev_req_id_to_index is not None
prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
num_accepted = valid_sampled_token_count[prev_req_index] - 1
num_rejected = req_state.prev_num_draft_len - num_accepted
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)
# Update the cached states. # Update the cached states.
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
req_index = self.input_batch.req_id_to_index.get(req_id)
if not is_last_rank: if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
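
The same accounting happens on the worker side, driven by the acceptance counts copied back from the GPU: the previous step scheduled prev_num_draft_len drafts, the sampler reports how many tokens were kept (accepted drafts plus the bonus token), and the difference is subtracted from the scheduler-reported num_computed_tokens. A small illustrative sketch, not the actual GPUModelRunner API:

# Illustrative only; mirrors the prev_num_draft_len bookkeeping in
# _update_states, not the actual GPUModelRunner API.
def adjust_for_prev_drafts(
    num_computed_tokens: int,
    prev_num_draft_len: int,
    valid_sampled_count: int,   # accepted drafts + 1 bonus token
    output_token_ids: list[int],
) -> int:
    num_accepted = valid_sampled_count - 1
    num_rejected = prev_num_draft_len - num_accepted
    # Placeholders for the accepted tokens; the real ids arrive with the
    # deferred sampler output.
    output_token_ids.extend([-1] * num_accepted)
    return num_computed_tokens - num_rejected

outputs: list[int] = []
# 100-token prompt, 2 drafts scheduled last step, 1 accepted (2 valid tokens).
assert adjust_for_prev_drafts(102, 2, 2, outputs) == 101
assert outputs == [-1]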
@ -823,8 +873,11 @@ class GPUModelRunner(
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, [] req_id, []
) )
if spec_token_ids: num_spec_tokens = len(spec_token_ids)
num_spec_tokens = len(spec_token_ids) # For async scheduling, token_ids_cpu assigned from
# spec_token_ids are placeholders and will be overwritten in
# _prepare_input_ids.
if num_spec_tokens:
start_index = self.input_batch.num_tokens_no_spec[req_index] start_index = self.input_batch.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens end_token_index = start_index + num_spec_tokens
self.input_batch.token_ids_cpu[ self.input_batch.token_ids_cpu[
@ -840,6 +893,15 @@ class GPUModelRunner(
# even when speculative decoding is enabled. # even when speculative decoding is enabled.
self.input_batch.spec_token_ids[req_index] = spec_token_ids self.input_batch.spec_token_ids[req_index] = spec_token_ids
# With async scheduling, if no draft tokens were produced in the
# previous step, clear the spec decoding info in scheduler_output
# so normal sampling is used instead of rejection sampling.
if self.use_async_scheduling:
req_state.prev_num_draft_len = num_spec_tokens
if num_spec_tokens and self._draft_token_ids is None:
scheduler_output.total_num_scheduled_tokens -= num_spec_tokens
scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens
scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None)
# Add the new or resumed requests to the persistent batch. # Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
for request in reqs_to_add: for request in reqs_to_add:
@ -959,7 +1021,10 @@ class GPUModelRunner(
return cu_num_tokens, arange return cu_num_tokens, arange
def _prepare_input_ids( def _prepare_input_ids(
self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray self,
scheduler_output: "SchedulerOutput",
total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray,
) -> None: ) -> None:
"""Prepare the input IDs for the current batch. """Prepare the input IDs for the current batch.
@ -980,21 +1045,43 @@ class GPUModelRunner(
# on the GPU from prev_sampled_token_ids. # on the GPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None assert prev_req_id_to_index is not None
flattened_indices = [] sample_flattened_indices: list[int] = []
prev_common_req_indices = [] spec_flattened_indices: list[int] = []
prev_common_req_indices: list[int] = []
prev_draft_token_indices: list[int] = []
indices_match = True indices_match = True
max_flattened_index = -1 max_flattened_index = -1
total_num_spec_tokens = 0
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id, cur_index in self.input_batch.req_id_to_index.items(): for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None: if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index) prev_common_req_indices.append(prev_index)
# We need to compute the flattened input_ids index of the # We need to compute the flattened input_ids index of the
# last token in each common request. # last token in each common request.
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1 flattened_index = cu_num_tokens[cur_index].item() - 1
flattened_indices.append(flattened_index) # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
# sample_flattened_indices = [0, 2, 5]
# spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(
range(flattened_index - draft_len + 1, flattened_index + 1)
)
start = prev_index * self.num_spec_tokens
# prev_draft_token_indices selects which of the previous step's
# draft token ids should be copied into input_ids.
# example: prev draft_token_ids [[1, 2], [3, 4], [5, 6]]
# flattened draft_token_ids [1, 2, 3, 4, 5, 6]
# draft_len of each request [1, 2, 1]
# then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start, start + draft_len))
indices_match &= prev_index == flattened_index indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index) max_flattened_index = max(max_flattened_index, flattened_index)
num_commmon_tokens = len(flattened_indices) num_commmon_tokens = len(sample_flattened_indices)
if num_commmon_tokens < total_num_scheduled_tokens: total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
if num_commmon_tokens < total_without_spec:
# If not all requests are decodes from the last iteration, # If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the GPU first. # We need to copy the input_ids_cpu to the GPU first.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens) self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
@ -1018,20 +1105,43 @@ class GPUModelRunner(
self.is_token_ids.gpu[:num_commmon_tokens] = True self.is_token_ids.gpu[:num_commmon_tokens] = True
return return
# Upload the index tensors asynchronously so the scatter can be non-blocking. # Upload the index tensors asynchronously so the scatter can be non-blocking.
input_ids_index_tensor = torch.tensor( sampled_tokens_index_tensor = torch.tensor(
flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True) ).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor( prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True) ).to(self.device, non_blocking=True)
self.input_ids.gpu.scatter_( self.input_ids.gpu.scatter_(
dim=0, dim=0,
index=input_ids_index_tensor, index=sampled_tokens_index_tensor,
src=self.input_batch.prev_sampled_token_ids[ src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0 prev_common_req_indices_tensor, 0
], ],
) )
# Scatter the draft tokens after the sampled tokens are scattered.
if self._draft_token_ids is None or not spec_flattened_indices:
return
assert isinstance(self._draft_token_ids, torch.Tensor)
draft_tokens_index_tensor = torch.tensor(
spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
prev_draft_token_indices_tensor = torch.tensor(
prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
# input_ids dtype is torch.int32, so convert draft_token_ids
# to torch.int32 here.
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
self._draft_token_ids = None
self.input_ids.gpu.scatter_(
dim=0,
index=draft_tokens_index_tensor,
src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
)
def _get_encoder_seq_lens( def _get_encoder_seq_lens(
self, self,
scheduled_encoder_inputs: dict[str, list[int]], scheduled_encoder_inputs: dict[str, list[int]],
@ -1218,7 +1328,11 @@ class GPUModelRunner(
self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
# Copy the tensors to the GPU. # Copy the tensors to the GPU.
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) self._prepare_input_ids(
scheduler_output,
total_num_scheduled_tokens,
cu_num_tokens,
)
if self.uses_mrope: if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@ -2377,12 +2491,14 @@ class GPUModelRunner(
valid_sampled_token_ids = [] valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist() invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
invalid_req_indices_set = set(invalid_req_indices) invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids.shape[-1] == 1
# Cache the sampled tokens on the GPU and avoid CPU sync. # Cache the sampled tokens on the GPU and avoid CPU sync.
# These will be copied into input_ids in the next step # These will be copied into input_ids in the next step
# when preparing inputs. # when preparing inputs.
self.input_batch.prev_sampled_token_ids = sampled_token_ids # With spec decoding, this is done in propose_draft_token_ids().
if self.input_batch.prev_sampled_token_ids is None:
assert sampled_token_ids.shape[-1] == 1
self.input_batch.prev_sampled_token_ids = sampled_token_ids
self.input_batch.prev_req_id_to_index = { self.input_batch.prev_req_id_to_index = {
req_id: i req_id: i
for i, req_id in enumerate(self.input_batch.req_ids) for i, req_id in enumerate(self.input_batch.req_ids)
@ -2517,6 +2633,21 @@ class GPUModelRunner(
"State error: sample_tokens() must be called " "State error: sample_tokens() must be called "
"after execute_model() returns None." "after execute_model() returns None."
) )
# self._draft_token_ids is None when `input_fits_in_drafter=False`
# and no draft tokens were proposed, so with async scheduling we need
# to update the spec decoding info in scheduler_output.
# Use deepcopy so the modification doesn't affect the scheduler_output
# held by the engine core process.
# TODO(Ronald1995): deepcopy is expensive with a large number of
# requests; optimize it later.
if (
self.use_async_scheduling
and self.num_spec_tokens
and self._draft_token_ids is None
):
scheduler_output = deepcopy(scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("gpu_model_runner: preprocess"): with record_function_or_nullcontext("gpu_model_runner: preprocess"):
with self.synchronize_input_prep(): with self.synchronize_input_prep():
@ -2759,6 +2890,8 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: sample"): with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata) sampler_output = self._sample(logits, spec_decode_metadata)
self.input_batch.prev_sampled_token_ids = None
def propose_draft_token_ids( def propose_draft_token_ids(
sampled_token_ids: torch.Tensor | list[np.ndarray], sampled_token_ids: torch.Tensor | list[np.ndarray],
) -> None: ) -> None:
@ -2792,14 +2925,29 @@ class GPUModelRunner(
self.speculative_config.draft_model_config.max_model_len self.speculative_config.draft_model_config.max_model_len
) )
input_fits_in_drafter = spec_decode_common_attn_metadata and ( input_fits_in_drafter = spec_decode_common_attn_metadata and (
spec_decode_common_attn_metadata.max_seq_len spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
+ self.speculative_config.num_speculative_tokens
<= effective_drafter_max_model_len <= effective_drafter_max_model_len
) )
if use_padded_batch_for_eagle and input_fits_in_drafter: if use_padded_batch_for_eagle:
# EAGLE speculative decoding can use the GPU sampled tokens sampled_token_ids = sampler_output.sampled_token_ids
# as inputs, and does not need to wait for bookkeeping to finish. if input_fits_in_drafter:
propose_draft_token_ids(sampler_output.sampled_token_ids) # EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None:
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
spec_decode_common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_indices.gpu,
self.num_discarded_requests,
)
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
with record_function_or_nullcontext("gpu_model_runner: bookkeep"): with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
( (
@ -2856,6 +3004,7 @@ class GPUModelRunner(
logprobs_tensors=sampler_output.logprobs_tensors, logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices, invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream, async_output_copy_stream=self.async_output_copy_stream,
vocab_size=self.input_batch.vocab_size,
) )
with record_function_or_nullcontext( with record_function_or_nullcontext(
"gpu_model_runner: set_async_sampled_token_ids" "gpu_model_runner: set_async_sampled_token_ids"
@ -2880,6 +3029,37 @@ class GPUModelRunner(
self._draft_token_ids = None self._draft_token_ids = None
return DraftTokenIds(req_ids, draft_token_ids) return DraftTokenIds(req_ids, draft_token_ids)
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
) -> None:
if self.valid_sampled_token_count_event is None:
return
default_stream = torch.cuda.current_stream()
# Use the dedicated copy stream to overlap the copy operation with
# the draft model's prepare_input.
with torch.cuda.stream(self.valid_sampled_token_count_copy_stream):
self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore
counts = valid_sampled_tokens_count
counts_cpu = self.valid_sampled_token_count_cpu
counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
self.valid_sampled_token_count_event.record()
self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)
def _get_valid_sampled_token_count(self) -> list[int]:
# Wait until valid_sampled_tokens_count has been copied to the CPU.
prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
if (
self.valid_sampled_token_count_event is None
or prev_sampled_token_ids is None
):
return []
counts_cpu = self.valid_sampled_token_count_cpu
self.valid_sampled_token_count_event.synchronize()
return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
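
The two helpers above implement a standard CUDA producer/consumer pattern: the device-to-host copy of the acceptance counts is launched on a dedicated stream after waiting on the default stream, an event is recorded, and the next step blocks on that event only when it actually needs the values. A generic sketch of the pattern, independent of the runner (requires a CUDA device):

import torch

def main() -> None:
    # Generic sketch of the "copy on a side stream, synchronize on an event"
    # pattern used for valid_sampled_token_count.
    copy_stream = torch.cuda.Stream()
    copy_done = torch.cuda.Event()
    counts_cpu = torch.empty(64, dtype=torch.int64, pin_memory=True)

    counts_gpu = torch.arange(8, device="cuda", dtype=torch.int64)

    # Producer: make sure the default-stream kernels are done, then launch
    # the non-blocking D2H copy on the side stream and record an event.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        counts_cpu[: counts_gpu.shape[0]].copy_(counts_gpu, non_blocking=True)
        copy_done.record()

    # Consumer (next step): block only when the values are actually needed.
    copy_done.synchronize()
    print(counts_cpu[:8].tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]

if __name__ == "__main__" and torch.cuda.is_available():
    main()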
def propose_draft_token_ids( def propose_draft_token_ids(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
@ -2967,6 +3147,9 @@ class GPUModelRunner(
self.num_discarded_requests, self.num_discarded_requests,
) )
) )
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
if spec_decode_metadata is None: if spec_decode_metadata is None:
token_indices_to_sample = None token_indices_to_sample = None
@ -3532,7 +3715,7 @@ class GPUModelRunner(
# TODO(luka) better system for describing dummy batches # TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else: else:
seq_lens = max_query_len seq_lens = max_query_len # type: ignore[assignment]
self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0 self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu() self.seq_lens.copy_to_gpu()
@ -4485,11 +4668,7 @@ class GPUModelRunner(
logitsprocs=self.input_batch.logitsprocs, logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
num_speculative_tokens=( num_speculative_tokens=self.num_spec_tokens,
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config
else 0
),
) )
def _allocate_kv_cache_tensors( def _allocate_kv_cache_tensors(