[Core] Async Scheduling X Spec Decoding Compatibility (#24799)
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
parent
f8b19c0ffd
commit
d8874c61a5
@@ -15,7 +15,7 @@ from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal

MODEL = "Qwen/Qwen3-0.6B"
MTP_MODEL = "XiaomiMiMo/MiMo-7B-Base"
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"

first_prompt = (

@@ -29,7 +29,8 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [

default_params = dict(
temperature=0.0,  # greedy
max_tokens=20,
max_tokens=23,
min_tokens=18,
)

@@ -69,15 +70,9 @@ def test_without_spec_decoding(
(True, "uni", True, None, True),
]

run_tests(
monkeypatch,
MODEL,
test_configs,
test_sampling_params,
)
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)

@pytest.mark.skip("MTP model too big to run in fp32 in CI")
def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,

@@ -85,8 +80,9 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
"""

spec_config = {
"method": "mtp",
"method": "eagle3",
"num_speculative_tokens": 2,
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
}
spec_config_short = spec_config | {"max_model_len": 50}

@@ -106,12 +102,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True),
]

run_tests(
monkeypatch,
MTP_MODEL,
test_configs,
[{}],
)
run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])

@dynamo_config.patch(cache_size_limit=16)

@@ -182,15 +173,13 @@ def run_tests(
and test_acceptance_rate is not None
):
if "spec_mml=None" in test_config:
# because the acceptance rate can vary, we use a looser
# tolerance here.
assert (
pytest.approx(test_acceptance_rate, rel=5e-2)
== base_acceptance_rate
)
else:
# Currently the reported acceptance rate is expected to be
# lower when we skip drafting altogether.
# lower when we sometimes skip drafting altogether.
assert test_acceptance_rate > 0.05
print(
f"PASSED: config=[{test_config}], params={params}"

@@ -220,6 +209,7 @@ def run_test(
):
spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = (
# Force preemptions
dict(num_gpu_blocks_override=32)
if test_preemption
else dict(gpu_memory_utilization=0.9)

@@ -238,6 +228,7 @@ def run_test(
model,
max_model_len=512,
enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None,
# enforce_eager=True,
async_scheduling=async_scheduling,

@@ -255,10 +246,7 @@ def run_test(
results.append(
vllm_model.generate(
example_prompts,
sampling_params=SamplingParams(
**default_params,
**override_params,
),
sampling_params=SamplingParams(**default_params, **override_params),
return_logprobs=True,
)
)

@@ -270,9 +258,7 @@ def run_test(

if test_preemption:
preemptions = _get_count(
metrics_before,
metrics_after,
"vllm:num_preemptions",
metrics_before, metrics_after, "vllm:num_preemptions"
)
assert preemptions > 0, "preemption test had no preemptions"
@@ -3,7 +3,7 @@

import ast
import hashlib
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, get_args

from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass

@@ -29,31 +29,25 @@ else:

logger = init_logger(__name__)

SpeculativeMethod = Literal[
"ngram",
"eagle",
"eagle3",
"medusa",
"mlp_speculator",
"draft_model",
"deepseek_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"mimo_mtp",
"longcat_flash_mtp",
"pangu_ultra_moe_mtp",
"mtp",
"suffix",
]
MTP_MODEL_TYPES = (
MTPModelTypes = Literal[
"deepseek_mtp",
"mimo_mtp",
"glm4_moe_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
)
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
"mlp_speculator",
"draft_model",
"suffix",
EagleModelTypes,
]

@config

@@ -244,7 +238,7 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by
# default.

if self.method in MTP_MODEL_TYPES:
if self.method in get_args(MTPModelTypes) and self.method != "mtp":
logger.warning(
"method `%s` is deprecated and replaced with mtp.", self.method
)

@@ -361,7 +355,9 @@ class SpeculativeConfig:
self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
self.method = "mlp_speculator"
elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
elif self.draft_model_config.hf_config.model_type in get_args(
MTPModelTypes
):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(
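For readers unfamiliar with the nested Literal trick used above: per PEP 586, a Literal that contains another Literal is flattened, so typing.get_args() on EagleModelTypes yields the EAGLE variants plus every MTP member. A minimal standalone sketch of the membership checks the diff relies on (the aliases are re-declared here purely for illustration, not imported from vLLM):

from typing import Literal, get_args

MTPModelTypes = Literal["deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
                        "qwen3_next_mtp", "longcat_flash_mtp", "mtp", "pangu_ultra_moe_mtp"]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]

# Nested Literals are flattened at runtime, so get_args() sees every member.
assert "mimo_mtp" in get_args(EagleModelTypes)
assert "eagle3" in get_args(EagleModelTypes)

# The deprecation check in SpeculativeConfig then reduces to:
method = "mimo_mtp"
if method in get_args(MTPModelTypes) and method != "mtp":
    print(f"method `{method}` is deprecated and replaced with mtp.")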
@@ -14,13 +14,14 @@ from dataclasses import replace
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar, get_args

import torch
from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass

import vllm.envs as envs
from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid

@@ -374,10 +375,22 @@ class VllmConfig:
"Async scheduling is not yet compatible with "
"pipeline_parallel_size > 1."
)
# Currently, async scheduling only supports eagle speculative
# decoding.
if self.speculative_config is not None:
raise ValueError(
"Async scheduling is not yet compatible with speculative decoding."
)
if self.speculative_config.method not in get_args(EagleModelTypes):
raise ValueError(
"Currently, async scheduling is only supported "
"with EAGLE/MTP kind of speculative decoding"
)
if self.speculative_config.disable_padded_drafter_batch:
raise ValueError(
"async scheduling for EAGLE/MTP kind of speculative "
"decoding is enabled, but disable_padded_drafter_batch=True "
"is not supported in this situation. Please set "
"disable_padded_drafter_batch=False"
)
if not executor_supports_async_sched:
raise ValueError(
"Currently, async scheduling only supports `mp`, `uni`, or "
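To make the new constraint concrete, here is a hedged sketch of an offline-inference setup that satisfies the checks above (EAGLE-family method, padded drafter batch left enabled). The model and speculator names are taken from the test file earlier in this diff and may need adjusting to your environment:

from vllm import LLM, SamplingParams

# Async scheduling together with EAGLE/MTP speculative decoding; a
# non-EAGLE method (e.g. "ngram") or disable_padded_drafter_batch=True
# would be rejected by the VllmConfig validation shown above.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    speculative_config={
        "method": "eagle3",
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
        "num_speculative_tokens": 2,
    },
    async_scheduling=True,
)
outputs = llm.generate(
    ["In one word, the capital of France is "],
    SamplingParams(temperature=0.0, max_tokens=8),
)
print(outputs[0].outputs[0].text)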
@@ -16,18 +16,25 @@ class AsyncScheduler(Scheduler):
) -> None:
super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
if (
request.num_computed_tokens
== request.num_tokens + request.num_output_placeholders
== request.num_tokens
+ request.num_output_placeholders
+ cur_num_spec_tokens
):
# The request will generate a new token in this scheduling step.
# TODO(woosuk): Support speculative decoding.
request.num_output_placeholders += 1
# The request will generate a new token plus num_spec_tokens
# in this scheduling step.
request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new tokens in spec_token_ids.
# We will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens

scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
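A tiny worked example of the placeholder accounting above, in plain Python with a bare namespace standing in for the request object (not a vLLM class):

from types import SimpleNamespace

num_spec_tokens = 2
req = SimpleNamespace(num_tokens=100, num_computed_tokens=102,
                      num_output_placeholders=0, spec_token_ids=[])

cur_num_spec_tokens = 2  # drafts scheduled for this request in this step
if req.num_computed_tokens == req.num_tokens + req.num_output_placeholders + cur_num_spec_tokens:
    # One "real" token plus the drafts may be produced before the worker
    # reports back, so reserve placeholders for all of them and mark the
    # draft slots with -1 until the worker fills in the actual ids.
    req.num_output_placeholders += 1 + cur_num_spec_tokens
    req.spec_token_ids = [-1] * num_spec_tokens

assert req.num_output_placeholders == 3 and req.spec_token_ids == [-1, -1]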
@@ -348,7 +348,10 @@ class Scheduler(SchedulerInterface):
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (
num_new_tokens + request.num_computed_tokens - request.num_tokens
num_new_tokens
+ request.num_computed_tokens
- request.num_tokens
- request.num_output_placeholders
)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.

@@ -1024,7 +1027,12 @@ class Scheduler(SchedulerInterface):
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens.
request.num_computed_tokens -= num_rejected
if request.num_computed_tokens > 0:
request.num_computed_tokens -= num_rejected
# If async scheduling, num_output_placeholders also includes
# the scheduled spec tokens count and so is similarly adjusted.
if request.num_output_placeholders > 0:
request.num_output_placeholders -= num_rejected
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=num_draft_tokens,
@@ -198,6 +198,7 @@ class EngineCore:
self.step_fn = (
self.step if self.batch_queue is None else self.step_with_batch_queue
)
self.async_scheduling = vllm_config.scheduler_config.async_scheduling

# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.

@@ -341,7 +342,10 @@ class EngineCore:
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

def post_step(self, model_executed: bool) -> None:
if self.use_spec_decode and model_executed:
# With async scheduling we can't get the draft token ids in advance;
# they are updated in the worker process instead, so there is
# nothing to do here.
if not self.async_scheduling and self.use_spec_decode and model_executed:
# Take the draft token ids.
draft_token_ids = self.model_executor.take_draft_token_ids()
if draft_token_ids is not None:
@@ -150,6 +150,23 @@ class Processor:
raise ValueError(
"vLLM V1 does not support per request user provided logits processors."
)
# Async scheduling + spec decode is currently incompatible with some
# sampling parameters.
if (
self.vllm_config.speculative_config is not None
and self.vllm_config.scheduler_config.async_scheduling
and (
params.frequency_penalty != 0.0
or params.presence_penalty != 0.0
or params.repetition_penalty != 1.0
or params.bad_words_token_ids
or params.structured_outputs
)
):
raise ValueError(
"async scheduling with spec decoding doesn't yet support "
"penalties, bad words or structured outputs in sampling parameters."
)

def _validate_params(
self,
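As a usage note, the request-level check above means sampling parameters like the following would be accepted or rejected when submitted to an engine running with async scheduling plus a speculator. A hedged sketch (the rejection happens at request validation time, not when constructing SamplingParams):

from vllm import SamplingParams

# Fine: plain greedy/temperature sampling.
ok = SamplingParams(temperature=0.0, max_tokens=32)

# Rejected by the Processor check above when async scheduling + spec decode
# are enabled, presumably because penalties depend on the actual output
# token ids, which are only placeholders in this mode.
rejected = SamplingParams(temperature=0.8, repetition_penalty=1.1, max_tokens=32)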
@@ -41,7 +41,7 @@ STR_POOLING_REJECTS_LOGITSPROCS = (
# Error message when the user tries to initialize vLLM with speculative
# decoding enabled and custom logits processors
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
"Custom logits processors are not supportedwhen speculative decoding is enabled."
"Custom logits processors are not supported when speculative decoding is enabled."
)

LOGITSPROCS_GROUP = "vllm.logits_processors"
@@ -397,10 +397,13 @@ class EagleProposer:
positions += 1
exceeds_max_model_len = positions >= self.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

# For data integrity with async scheduling, we shouldn't use in-place
# operations here, in case the tensors are modified by the next step's
# `prepare_input` of the main model.
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1
# This is an out-of-place operation to avoid modifying the original tensor.
common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
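The in-place vs. out-of-place distinction matters because the CPU tensor may still be referenced by the previous step's metadata. A tiny illustration of the aliasing hazard, in plain PyTorch and unrelated to the vLLM classes:

import torch

seq_lens_cpu = torch.tensor([7, 7, 7])
snapshot = seq_lens_cpu          # another component still holds this reference

seq_lens_cpu += 1                # in-place: the snapshot silently changes too
assert snapshot.tolist() == [8, 8, 8]

seq_lens_cpu = seq_lens_cpu + 1  # out-of-place: a new tensor, snapshot untouched
assert snapshot.tolist() == [8, 8, 8] and seq_lens_cpu.tolist() == [9, 9, 9]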
@@ -46,6 +46,9 @@ class CachedRequestState:
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None

# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0

def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
@@ -179,6 +179,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
logprobs_tensors: torch.Tensor | None,
invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream,
vocab_size: int,
):
self._model_runner_output = model_runner_output
self._invalid_req_indices = invalid_req_indices

@@ -189,6 +190,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
self._sampled_token_ids = sampled_token_ids
self.vocab_size = vocab_size
self._logprobs_tensors = logprobs_tensors

# Initiate the copy on a separate stream, but do not synchronize it.

@@ -215,10 +217,16 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
# Release the device tensors once the copy has completed.
del self._logprobs_tensors
del self._sampled_token_ids

valid_sampled_token_ids: list[np.ndarray] = [
row for row in self.sampled_token_ids_cpu.numpy()
]
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids: list[np.ndarray] = [
row for row in self.sampled_token_ids_cpu.numpy()
]
else:
valid_sampled_token_ids = RejectionSampler.parse_output(
self.sampled_token_ids_cpu,
self.vocab_size,
)
for i in self._invalid_req_indices:
valid_sampled_token_ids[i] = np.array([])
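For intuition: when max_gen_len > 1 each row holds the bonus token plus the draft positions, and rejected slots have to be filtered out per request. A hedged sketch of that parsing step, assuming (as an illustration only, not the exact RejectionSampler.parse_output semantics) that invalid slots are marked with a sentinel id outside [0, vocab_size):

import numpy as np

PLACEHOLDER_TOKEN_ID = -1  # sentinel used here purely for illustration
vocab_size = 32000

# One request per row, num_spec_tokens + 1 columns.
sampled = np.array([
    [11, 12, PLACEHOLDER_TOKEN_ID],                    # one draft accepted + bonus token
    [21, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],  # all drafts rejected
])

valid_sampled_token_ids = [
    row[(row != PLACEHOLDER_TOKEN_ID) & (row < vocab_size)] for row in sampled
]
assert [r.tolist() for r in valid_sampled_token_ids] == [[11, 12], [21]]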
@@ -377,6 +385,10 @@ class GPUModelRunner(
)
self.rejection_sampler = RejectionSampler(self.sampler)

self.num_spec_tokens = 0
if self.speculative_config:
self.num_spec_tokens = self.speculative_config.num_speculative_tokens

# Request states.
self.requests: dict[str, CachedRequestState] = {}
self.comm_stream = torch.cuda.Stream()

@@ -513,11 +525,7 @@ class GPUModelRunner(
self.max_num_tokens, dtype=torch.int32, device=self.device
)

self.uniform_decode_query_len = (
1
if not self.speculative_config
else 1 + self.speculative_config.num_speculative_tokens
)
self.uniform_decode_query_len = 1 + self.num_spec_tokens

# Cudagraph dispatcher for runtime cudagraph dispatching.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

@@ -549,6 +557,20 @@ class GPUModelRunner(
pin_memory=self.pin_memory,
)

# Pre-allocated tensor for copying valid sampled token counts to CPU,
# with dedicated stream for overlapping and event for coordination.
self.valid_sampled_token_count_event: torch.cuda.Event | None = None
self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
if self.use_async_scheduling and self.num_spec_tokens:
self.valid_sampled_token_count_event = torch.cuda.Event()
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
self.valid_sampled_token_count_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)

# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
@@ -736,17 +758,45 @@ class GPUModelRunner(
# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs

# Wait until valid_sampled_tokens_count has been copied to the CPU,
# then use it to update the actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count()

for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_id in req_data.resumed_req_ids
num_output_tokens = req_data.num_output_tokens[i]
req_index = self.input_batch.req_id_to_index.get(req_id)

# prev_num_draft_len is used in async scheduling mode with
# spec decode. It indicates whether num_computed_tokens of the
# request needs to be corrected. For example:
# first step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0.
# second step: num_computed_tokens = 100 (prompt length),
# spec_tokens = [a, b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c, d],
# prev_num_draft_len = 2.
# num_computed_tokens in the first and second steps doesn't include
# the spec tokens length, but in the third step it does, so we only
# need to correct num_computed_tokens when prev_num_draft_len > 0.
if req_state.prev_num_draft_len:
if req_index is None:
req_state.prev_num_draft_len = 0
else:
assert self.input_batch.prev_req_id_to_index is not None
prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
num_accepted = valid_sampled_token_count[prev_req_index] - 1
num_rejected = req_state.prev_num_draft_len - num_accepted
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)

# Update the cached states.

req_state.num_computed_tokens = num_computed_tokens
req_index = self.input_batch.req_id_to_index.get(req_id)

if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
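A worked example of the correction above, with made-up numbers (plain Python, no vLLM objects): suppose the previous step scheduled 2 draft tokens for this request and the rejection sampler kept 2 tokens in total (one accepted draft plus the bonus token):

prev_num_draft_len = 2          # drafts scheduled for this request last step
valid_sampled_token_count = 2   # tokens that survived rejection sampling (drafts + bonus)

num_accepted = valid_sampled_token_count - 1        # 1 accepted draft
num_rejected = prev_num_draft_len - num_accepted    # 1 rejected draft

# The scheduler optimistically counted all drafts as computed, so the
# worker walks num_computed_tokens back by the rejected amount ...
num_computed_tokens = 102 - num_rejected            # 100 (prompt) + 2 (drafts) - 1
assert num_computed_tokens == 101

# ... and appends placeholder output ids for the accepted drafts; the real
# ids are filled in once the async output reaches the CPU.
output_token_ids = [-1] * num_accepted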
@@ -823,8 +873,11 @@ class GPUModelRunner(
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, []
)
if spec_token_ids:
num_spec_tokens = len(spec_token_ids)
num_spec_tokens = len(spec_token_ids)
# For async scheduling, the token_ids_cpu entries assigned from
# spec_token_ids are placeholders and will be overwritten in
# _prepare_input_ids.
if num_spec_tokens:
start_index = self.input_batch.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens
self.input_batch.token_ids_cpu[

@@ -840,6 +893,15 @@ class GPUModelRunner(
# even when speculative decoding is enabled.
self.input_batch.spec_token_ids[req_index] = spec_token_ids

# If there are no draft tokens available with async scheduling, clear
# the spec-decoding info in scheduler_output and fall back to normal
# sampling instead of rejection sampling.
if self.use_async_scheduling:
req_state.prev_num_draft_len = num_spec_tokens
if num_spec_tokens and self._draft_token_ids is None:
scheduler_output.total_num_scheduled_tokens -= num_spec_tokens
scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens
scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
for request in reqs_to_add:
@@ -959,7 +1021,10 @@ class GPUModelRunner(
return cu_num_tokens, arange

def _prepare_input_ids(
self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray
self,
scheduler_output: "SchedulerOutput",
total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray,
) -> None:
"""Prepare the input IDs for the current batch.

@@ -980,21 +1045,43 @@ class GPUModelRunner(
# on the GPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None
flattened_indices = []
prev_common_req_indices = []
sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = []
prev_common_req_indices: list[int] = []
prev_draft_token_indices: list[int] = []
indices_match = True
max_flattened_index = -1
total_num_spec_tokens = 0
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens

for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
# We need to compute the flattened input_ids index of the
# last token in each common request.
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
flattened_indices.append(flattened_index)
# Example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
# -> sample_flattened_indices = [0, 2, 5]
# -> spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(
range(flattened_index - draft_len + 1, flattened_index + 1)
)
start = prev_index * self.num_spec_tokens
# prev_draft_token_indices selects which of the previous step's
# draft token ids should be copied into input_ids.
# Example: prev draft_token_ids [[1, 2], [3, 4], [5, 6]]
# flattened draft_token_ids [1, 2, 3, 4, 5, 6]
# draft_len of each request [1, 2, 1]
# -> prev_draft_token_indices = [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start, start + draft_len))
indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
num_commmon_tokens = len(flattened_indices)
if num_commmon_tokens < total_num_scheduled_tokens:
num_commmon_tokens = len(sample_flattened_indices)
total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
if num_commmon_tokens < total_without_spec:
# If not all requests are decodes from the last iteration,
# we need to copy the input_ids_cpu to the GPU first.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)

@@ -1018,20 +1105,43 @@ class GPUModelRunner(
self.is_token_ids.gpu[:num_commmon_tokens] = True
return
# Upload the index tensors asynchronously so the scatter can be non-blocking.
input_ids_index_tensor = torch.tensor(
flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
sampled_tokens_index_tensor = torch.tensor(
sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
self.input_ids.gpu.scatter_(
dim=0,
index=input_ids_index_tensor,
index=sampled_tokens_index_tensor,
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0
],
)

# Scatter the draft tokens after the sampled tokens are scattered.
if self._draft_token_ids is None or not spec_flattened_indices:
return

assert isinstance(self._draft_token_ids, torch.Tensor)
draft_tokens_index_tensor = torch.tensor(
spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
prev_draft_token_indices_tensor = torch.tensor(
prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)

# input_ids has dtype torch.int32, so convert draft_token_ids to
# torch.int32 here.
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
self._draft_token_ids = None

self.input_ids.gpu.scatter_(
dim=0,
index=draft_tokens_index_tensor,
src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
)
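The index bookkeeping above is easier to see end-to-end in a small NumPy sketch that reproduces the numbers from the comments (standalone, no vLLM objects; it assumes, for simplicity, that each request's previous-batch index equals its current index):

import numpy as np

cu_num_tokens = np.array([2, 5, 8])   # cumulative scheduled tokens per request
draft_lens = [1, 2, 2]                # scheduled draft tokens per request
num_spec_tokens = 2                   # drafter proposes up to 2 tokens per step

sample_flattened_indices, spec_flattened_indices, prev_draft_token_indices = [], [], []
for cur_index, draft_len in enumerate(draft_lens):
    flattened_index = int(cu_num_tokens[cur_index]) - 1       # last token of this request
    sample_flattened_indices.append(flattened_index - draft_len)
    spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1))
    # Pick the right slice of the previous step's flattened draft buffer.
    start = cur_index * num_spec_tokens
    prev_draft_token_indices.extend(range(start, start + draft_len))

assert sample_flattened_indices == [0, 2, 5]
assert spec_flattened_indices == [1, 3, 4, 6, 7]
assert prev_draft_token_indices == [0, 2, 3, 4, 5]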
def _get_encoder_seq_lens(
self,
scheduled_encoder_inputs: dict[str, list[int]],

@@ -1218,7 +1328,11 @@ class GPUModelRunner(
self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)

# Copy the tensors to the GPU.
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
self._prepare_input_ids(
scheduler_output,
total_num_scheduled_tokens,
cu_num_tokens,
)

if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)

@@ -2377,12 +2491,14 @@ class GPUModelRunner(
valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids.shape[-1] == 1

# Cache the sampled tokens on the GPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
self.input_batch.prev_sampled_token_ids = sampled_token_ids
# With spec decoding, this is done in propose_draft_token_ids().
if self.input_batch.prev_sampled_token_ids is None:
assert sampled_token_ids.shape[-1] == 1
self.input_batch.prev_sampled_token_ids = sampled_token_ids
self.input_batch.prev_req_id_to_index = {
req_id: i
for i, req_id in enumerate(self.input_batch.req_ids)
@@ -2517,6 +2633,21 @@ class GPUModelRunner(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)

# self._draft_token_ids is None when `input_fits_in_drafter=False`
# and no draft tokens were scheduled, so the spec-decoding info in
# scheduler_output needs to be updated when async scheduling is on.
# Use deepcopy so the modification doesn't affect the scheduler_output
# held by the engine core process.
# TODO(Ronald1995): deepcopy is expensive when there is a large
# number of requests, optimize it later.
if (
self.use_async_scheduling
and self.num_spec_tokens
and self._draft_token_ids is None
):
scheduler_output = deepcopy(scheduler_output)

num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("gpu_model_runner: preprocess"):
with self.synchronize_input_prep():

@@ -2759,6 +2890,8 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)

self.input_batch.prev_sampled_token_ids = None

def propose_draft_token_ids(
sampled_token_ids: torch.Tensor | list[np.ndarray],
) -> None:

@@ -2792,14 +2925,29 @@ class GPUModelRunner(
self.speculative_config.draft_model_config.max_model_len
)
input_fits_in_drafter = spec_decode_common_attn_metadata and (
spec_decode_common_attn_metadata.max_seq_len
+ self.speculative_config.num_speculative_tokens
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= effective_drafter_max_model_len
)
if use_padded_batch_for_eagle and input_fits_in_drafter:
# EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampler_output.sampled_token_ids)
if use_padded_batch_for_eagle:
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
# EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None:
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
spec_decode_common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_indices.gpu,
self.num_discarded_requests,
)
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)

with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(

@@ -2856,6 +3004,7 @@ class GPUModelRunner(
logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
vocab_size=self.input_batch.vocab_size,
)
with record_function_or_nullcontext(
"gpu_model_runner: set_async_sampled_token_ids"

@@ -2880,6 +3029,37 @@ class GPUModelRunner(
self._draft_token_ids = None
return DraftTokenIds(req_ids, draft_token_ids)

def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
) -> None:
if self.valid_sampled_token_count_event is None:
return

default_stream = torch.cuda.current_stream()
# Use a dedicated stream to overlap the copy operation with the
# prepare_input of the draft model.
with torch.cuda.stream(self.valid_sampled_token_count_copy_stream):
self.valid_sampled_token_count_copy_stream.wait_stream(default_stream)  # type: ignore
counts = valid_sampled_tokens_count
counts_cpu = self.valid_sampled_token_count_cpu
counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
self.valid_sampled_token_count_event.record()

self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)

def _get_valid_sampled_token_count(self) -> list[int]:
# Wait until valid_sampled_tokens_count has been copied to the CPU.
prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
if (
self.valid_sampled_token_count_event is None
or prev_sampled_token_ids is None
):
return []

counts_cpu = self.valid_sampled_token_count_cpu
self.valid_sampled_token_count_event.synchronize()
return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
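The copy/consume pair above follows the standard CUDA side-stream pattern: launch the device-to-host copy into pinned memory on a secondary stream that waits on the producing stream, record an event, and only synchronize on that event when the value is actually consumed in the next step. A stripped-down sketch of the same pattern (plain PyTorch, hypothetical names, not the vLLM implementation):

import torch

copy_stream = torch.cuda.Stream()
copy_done = torch.cuda.Event()
counts_cpu = torch.empty(256, dtype=torch.int64, pin_memory=True)

def launch_async_copy(counts_gpu: torch.Tensor) -> None:
    default_stream = torch.cuda.current_stream()
    # Run the D2H copy on a side stream so the default stream can keep
    # preparing the drafter's inputs; the event marks completion.
    with torch.cuda.stream(copy_stream):
        # Don't start copying before the default stream has produced counts_gpu.
        copy_stream.wait_stream(default_stream)
        counts_cpu[: counts_gpu.shape[0]].copy_(counts_gpu, non_blocking=True)
        copy_done.record()

def read_counts(num_reqs: int) -> list[int]:
    # Block only here, right before the values are needed.
    copy_done.synchronize()
    return counts_cpu[:num_reqs].tolist()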
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",

@@ -2967,6 +3147,9 @@ class GPUModelRunner(
self.num_discarded_requests,
)
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)

if spec_decode_metadata is None:
token_indices_to_sample = None

@@ -3532,7 +3715,7 @@ class GPUModelRunner(
# TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else:
seq_lens = max_query_len
seq_lens = max_query_len  # type: ignore[assignment]
self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()

@@ -4485,11 +4668,7 @@ class GPUModelRunner(
logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
num_speculative_tokens=(
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config
else 0
),
num_speculative_tokens=self.num_spec_tokens,
)

def _allocate_kv_cache_tensors(