[Perf][V1] Fully overlap model execution (#23569)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
This commit is contained in:
Benjamin Chislett 2025-09-05 21:20:17 -04:00 committed by GitHub
parent c954c6629c
commit cee182b297
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 252 additions and 31 deletions

View File

@ -3,6 +3,7 @@
import multiprocessing import multiprocessing
import os import os
import pickle import pickle
import queue
import signal import signal
import threading import threading
import time import time
@ -33,7 +34,8 @@ from vllm.utils import (decorate_logs, get_distributed_init_method,
get_loopback_ip, get_mp_context, get_open_port, get_loopback_ip, get_mp_context, get_open_port,
set_process_title) set_process_title)
from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
ModelRunnerOutput)
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)
@ -414,6 +416,16 @@ class WorkerProc:
# Initializes a message queue for sending the model output # Initializes a message queue for sending the model output
self.worker_response_mq = MessageQueue(1, 1) self.worker_response_mq = MessageQueue(1, 1)
scheduler_config = vllm_config.scheduler_config
self.use_async_scheduling = scheduler_config.async_scheduling
if self.use_async_scheduling:
self.async_output_queue: queue.Queue = queue.Queue()
self.async_output_copy_thread = Thread(
target=self.async_output_busy_loop,
daemon=True,
name="WorkerAsyncOutputCopy")
self.async_output_copy_thread.start()
# Initialize device and loads weights # Initialize device and loads weights
self.worker.init_device() self.worker.init_device()
self.worker.load_model() self.worker.load_model()
@ -595,6 +607,36 @@ class WorkerProc:
SUCCESS = auto() SUCCESS = auto()
FAILURE = auto() FAILURE = auto()
def enqueue_output(self, output: Any):
"""Prepares output from the worker and enqueues it to the
worker_response_mq. If the output is an Exception, it is
converted to a FAILURE response.
"""
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output()
if isinstance(output, Exception):
result = (WorkerProc.ResponseStatus.FAILURE, str(output))
else:
result = (WorkerProc.ResponseStatus.SUCCESS, output)
self.worker_response_mq.enqueue(result)
def handle_output(self, output: Any):
"""Handles output from the worker. If async scheduling is enabled,
it is passed to the async_output_busy_loop thread. Otherwise, it is
enqueued directly to the worker_response_mq.
"""
if self.use_async_scheduling:
self.async_output_queue.put(output)
else:
self.enqueue_output(output)
def async_output_busy_loop(self):
"""Entrypoint for the thread which handles outputs asynchronously."""
while True:
output = self.async_output_queue.get()
self.enqueue_output(output)
def worker_busy_loop(self): def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers""" """Main busy loop for Multiprocessing Workers"""
while True: while True:
@ -614,10 +656,8 @@ class WorkerProc:
# exception might not be serializable, so we convert it to # exception might not be serializable, so we convert it to
# string, only for logging purpose. # string, only for logging purpose.
if output_rank is None or self.rank == output_rank: if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue( self.handle_output(e)
(WorkerProc.ResponseStatus.FAILURE, str(e)))
continue continue
if output_rank is None or self.rank == output_rank: if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue( self.handle_output(output)
(WorkerProc.ResponseStatus.SUCCESS, output))

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import NamedTuple, Optional from typing import NamedTuple, Optional
@ -114,6 +115,20 @@ class ModelRunnerOutput:
num_nans_in_logits: Optional[dict[str, int]] = None num_nans_in_logits: Optional[dict[str, int]] = None
# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):
@abstractmethod
def get_output(self) -> ModelRunnerOutput:
"""Get the ModelRunnerOutput for this async output.
This is a blocking call that waits until the results are ready, which
might involve copying device tensors to the host.
This method should only be called once per AsyncModelRunnerOutput.
"""
pass
@dataclass @dataclass
class DraftTokenIds: class DraftTokenIds:

View File

@ -250,6 +250,11 @@ class InputBatch:
self.pooling_params: dict[str, PoolingParams] = {} self.pooling_params: dict[str, PoolingParams] = {}
# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: Optional[dict[str, int]] = None
@property @property
def req_ids(self) -> list[str]: def req_ids(self) -> list[str]:
# None elements should only be present transiently # None elements should only be present transiently

View File

@ -67,8 +67,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
FullAttentionSpec, KVCacheConfig, FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec, KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec) MambaSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
LogprobsTensors, ModelRunnerOutput) DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
@ -100,6 +100,53 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampled_token_ids: torch.Tensor,
invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream,
):
self._model_runner_output = model_runner_output
self._invalid_req_indices = invalid_req_indices
# Event on the copy stream so we can synchronize the non-blocking copy.
self._async_copy_ready_event = torch.cuda.Event()
# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
self._sampled_token_ids = sampled_token_ids
# Initiate the copy on a separate stream, but do not synchronize it.
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(async_output_copy_stream):
async_output_copy_stream.wait_stream(default_stream)
self._sampled_token_ids_cpu = self._sampled_token_ids.to(
'cpu', non_blocking=True)
self._async_copy_ready_event.record()
def get_output(self) -> ModelRunnerOutput:
"""Copy the device tensors to the host and return a ModelRunnerOutput.
This function blocks until the copy is finished.
"""
self._async_copy_ready_event.synchronize()
# Release the device tensor once the copy has completed
del self._sampled_token_ids
valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
return output
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__( def __init__(
@ -230,6 +277,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
) )
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.async_output_copy_stream = torch.cuda.Stream() if \
self.use_async_scheduling else None
# TODO(woosuk): Provide an option to tune the max cudagraph batch size. # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different. # The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order. # self.cudagraph_batch_sizes sorts in ascending order.
@ -654,6 +705,73 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return cu_num_tokens, arange return cu_num_tokens, arange
def _prepare_input_ids(self, total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray) -> None:
"""Prepare the input IDs for the current batch.
Carefully handles the `prev_sampled_token_ids` which can be cached
from the previous engine iteration, in which case those tokens on the
GPU need to be copied into the corresponding slots into input_ids."""
if self.input_batch.prev_sampled_token_ids is None:
# Normal scheduling case
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
return
# Async scheduling case, where some decode requests from the previous
# iteration won't have entries in input_ids_cpu and need to be copied
# on the GPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None
flattened_indices = []
prev_common_req_indices = []
indices_match = True
max_flattened_index = -1
for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
# We need to compute the flattened input_ids index of the
# last token in each common request.
flattened_index = cu_num_tokens[cur_index].item() - 1
flattened_indices.append(flattened_index)
indices_match &= (prev_index == flattened_index)
max_flattened_index = max(max_flattened_index, flattened_index)
num_commmon_tokens = len(flattened_indices)
if num_commmon_tokens < total_num_scheduled_tokens:
# If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the GPU first.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if num_commmon_tokens == 0:
# No requests in common with the previous iteration
# So input_ids_cpu will have all the input ids.
return
if indices_match and max_flattened_index == (num_commmon_tokens - 1):
# Common-case optimization: the batch is unchanged
# and no reordering happened.
# The indices are both the same permutation of 0..N-1 so
# we can copy directly using a single slice.
self.input_ids.gpu[:num_commmon_tokens].copy_(
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
0],
non_blocking=True)
return
# Upload the index tensors asynchronously
# so the scatter can be non-blocking.
input_ids_index_tensor = torch.tensor(flattened_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(
self.device,
non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
self.input_ids.gpu.scatter_(
dim=0,
index=input_ids_index_tensor,
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0])
def _prepare_inputs( def _prepare_inputs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
@ -740,7 +858,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_seq_len = self.seq_lens.np[:num_reqs].max().item() max_seq_len = self.seq_lens.np[:num_reqs].max().item()
# Copy the tensors to the GPU. # Copy the tensors to the GPU.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens) self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
if self.uses_mrope: if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
@ -1458,7 +1577,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]: ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens: if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group(): if not has_kv_transfer_group():
@ -1673,6 +1792,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# so that we could clear the sampled tokens before returning. # so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i) discard_sampled_tokens_req_indices.append(i)
# Copy some objects so they don't get modified after returning.
# This is important when using async scheduling.
req_ids_output_copy = self.input_batch.req_ids.copy()
req_id_to_index_output_copy = \
self.input_batch.req_id_to_index.copy()
# NOTE: GPU -> CPU Sync happens here. # NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point. # Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors logprobs_tensors = sampler_output.logprobs_tensors
@ -1685,21 +1810,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output.num_scheduled_tokens, scheduler_output.num_scheduled_tokens,
) )
# Get the valid generated tokens. num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1] if not self.use_async_scheduling:
if max_gen_len == 1: # Get the valid generated tokens.
# No spec decode tokens. max_gen_len = sampled_token_ids.shape[-1]
valid_sampled_token_ids = self._to_list(sampled_token_ids) if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids)
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
else: else:
# Includes spec decode tokens. valid_sampled_token_ids = []
valid_sampled_token_ids = self.rejection_sampler.parse_output( invalid_req_indices = list(discard_sampled_tokens_req_indices)
sampled_token_ids, invalid_req_indices_set = set(invalid_req_indices)
self.input_batch.vocab_size, assert sampled_token_ids.shape[-1] == 1
)
# Mask out the sampled tokens that should not be sampled. # Cache the sampled tokens on the GPU and avoid CPU sync.
for i in discard_sampled_tokens_req_indices: # These will be copied into input_ids in the next step
valid_sampled_token_ids[i].clear() # when preparing inputs.
self.input_batch.prev_sampled_token_ids = \
sampled_token_ids
self.input_batch.prev_sampled_token_ids_invalid_indices = \
invalid_req_indices_set
self.input_batch.prev_req_id_to_index = {
req_id: i
for i, req_id in enumerate(self.input_batch.req_ids)
if i not in invalid_req_indices_set
}
# Cache the sampled tokens in the model runner, so that the scheduler # Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back. # doesn't need to send them back.
@ -1707,7 +1852,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# the sampled tokens back, because there's no direct communication # the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker. # between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids req_ids = self.input_batch.req_ids
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): for req_idx in range(num_sampled_tokens):
if self.use_async_scheduling:
sampled_ids = [-1] if \
req_idx not in invalid_req_indices_set else None
else:
sampled_ids = valid_sampled_token_ids[req_idx]
if not sampled_ids: if not sampled_ids:
continue continue
@ -1722,6 +1872,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
start_idx:end_idx] = sampled_ids start_idx:end_idx] = sampled_ids
self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx
req_id = req_ids[req_idx] req_id = req_ids[req_idx]
req_state = self.requests[req_id] req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids) req_state.output_token_ids.extend(sampled_ids)
@ -1741,9 +1892,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.eplb_step() self.eplb_step()
return ModelRunnerOutput( output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids, req_ids=req_ids_output_copy,
req_id_to_index=self.input_batch.req_id_to_index, req_id_to_index=req_id_to_index_output_copy,
sampled_token_ids=valid_sampled_token_ids, sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists, logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict, prompt_logprobs_dict=prompt_logprobs_dict,
@ -1752,6 +1903,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_nans_in_logits=num_nans_in_logits, num_nans_in_logits=num_nans_in_logits,
) )
if not self.use_async_scheduling:
return output
return AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampled_token_ids,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None: if self._draft_token_ids is None:
return None return None

View File

@ -5,7 +5,7 @@ import copy
import gc import gc
import os import os
from contextlib import AbstractContextManager, nullcontext from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional, Union
import torch import torch
import torch.distributed import torch.distributed
@ -28,8 +28,8 @@ from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
ModelRunnerOutput) DraftTokenIds, ModelRunnerOutput)
from vllm.v1.utils import report_usage_stats from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
@ -355,7 +355,7 @@ class Worker(WorkerBase):
def execute_model( def execute_model(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]: ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
intermediate_tensors = None intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0 forward_pass = scheduler_output.total_num_scheduled_tokens > 0
if forward_pass and not get_pp_group().is_first_rank: if forward_pass and not get_pp_group().is_first_rank:
@ -365,7 +365,7 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(scheduler_output, output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors) intermediate_tensors)
if isinstance(output, ModelRunnerOutput): if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
return output return output
assert isinstance(output, IntermediateTensors) assert isinstance(output, IntermediateTensors)