[Perf][V1] Fully overlap model execution (#23569)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
This commit is contained in:
parent c954c6629c
commit cee182b297

vllm/v1/executor/multiproc_executor.py

@@ -3,6 +3,7 @@
 import multiprocessing
 import os
 import pickle
+import queue
 import signal
 import threading
 import time
@@ -33,7 +34,8 @@ from vllm.utils import (decorate_logs, get_distributed_init_method,
                         get_loopback_ip, get_mp_context, get_open_port,
                         set_process_title)
 from vllm.v1.executor.abstract import Executor, FailureCallback
-from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
+                             ModelRunnerOutput)
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -414,6 +416,16 @@ class WorkerProc:
         # Initializes a message queue for sending the model output
         self.worker_response_mq = MessageQueue(1, 1)

+        scheduler_config = vllm_config.scheduler_config
+        self.use_async_scheduling = scheduler_config.async_scheduling
+        if self.use_async_scheduling:
+            self.async_output_queue: queue.Queue = queue.Queue()
+            self.async_output_copy_thread = Thread(
+                target=self.async_output_busy_loop,
+                daemon=True,
+                name="WorkerAsyncOutputCopy")
+            self.async_output_copy_thread.start()
+
         # Initialize device and loads weights
         self.worker.init_device()
         self.worker.load_model()
@@ -595,6 +607,36 @@ class WorkerProc:
         SUCCESS = auto()
         FAILURE = auto()

+    def enqueue_output(self, output: Any):
+        """Prepares output from the worker and enqueues it to the
+        worker_response_mq. If the output is an Exception, it is
+        converted to a FAILURE response.
+        """
+        if isinstance(output, AsyncModelRunnerOutput):
+            output = output.get_output()
+
+        if isinstance(output, Exception):
+            result = (WorkerProc.ResponseStatus.FAILURE, str(output))
+        else:
+            result = (WorkerProc.ResponseStatus.SUCCESS, output)
+        self.worker_response_mq.enqueue(result)
+
+    def handle_output(self, output: Any):
+        """Handles output from the worker. If async scheduling is enabled,
+        it is passed to the async_output_busy_loop thread. Otherwise, it is
+        enqueued directly to the worker_response_mq.
+        """
+        if self.use_async_scheduling:
+            self.async_output_queue.put(output)
+        else:
+            self.enqueue_output(output)
+
+    def async_output_busy_loop(self):
+        """Entrypoint for the thread which handles outputs asynchronously."""
+        while True:
+            output = self.async_output_queue.get()
+            self.enqueue_output(output)
+
     def worker_busy_loop(self):
         """Main busy loop for Multiprocessing Workers"""
         while True:
@@ -614,10 +656,8 @@ class WorkerProc:
                 # exception might not be serializable, so we convert it to
                 # string, only for logging purpose.
-                if output_rank is None or self.rank == output_rank:
-                    self.worker_response_mq.enqueue(
-                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
+                self.handle_output(e)
                 continue

-            if output_rank is None or self.rank == output_rank:
-                self.worker_response_mq.enqueue(
-                    (WorkerProc.ResponseStatus.SUCCESS, output))
+            self.handle_output(output)
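
The new WorkerProc methods above form a simple hand-off: the busy loop calls handle_output(), which either publishes the result directly or defers to a daemon thread that resolves an AsyncModelRunnerOutput before publishing. The standalone sketch below (not part of the commit; FakeAsyncOutput and publish() are invented stand-ins for AsyncModelRunnerOutput and the response message queue) shows the same pattern with plain Python primitives:

import queue
import threading
import time

class FakeAsyncOutput:
    """Stand-in for AsyncModelRunnerOutput: get_output() blocks briefly."""

    def __init__(self, value):
        self._value = value

    def get_output(self):
        time.sleep(0.01)  # pretend to wait for a device-to-host copy
        return self._value

published = []

def publish(output):
    # Mirrors enqueue_output(): resolve async outputs before publishing.
    if isinstance(output, FakeAsyncOutput):
        output = output.get_output()
    published.append(output)

output_queue: queue.Queue = queue.Queue()

def busy_loop():
    # Mirrors async_output_busy_loop(): drain the queue on a side thread so
    # the worker's main loop never blocks on the copy.
    while True:
        item = output_queue.get()
        if item is None:  # sentinel used only in this sketch
            break
        publish(item)

t = threading.Thread(target=busy_loop, daemon=True, name="AsyncOutputCopy")
t.start()

for step in range(3):
    output_queue.put(FakeAsyncOutput(f"step-{step} tokens"))
output_queue.put(None)
t.join()
print(published)
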

vllm/v1/outputs.py

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import NamedTuple, Optional

@@ -114,6 +115,20 @@ class ModelRunnerOutput:
     num_nans_in_logits: Optional[dict[str, int]] = None


+# ModelRunnerOutput wrapper for async scheduling.
+class AsyncModelRunnerOutput(ABC):
+
+    @abstractmethod
+    def get_output(self) -> ModelRunnerOutput:
+        """Get the ModelRunnerOutput for this async output.
+
+        This is a blocking call that waits until the results are ready, which
+        might involve copying device tensors to the host.
+        This method should only be called once per AsyncModelRunnerOutput.
+        """
+        pass
+
+
 @dataclass
 class DraftTokenIds:

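
AsyncModelRunnerOutput only fixes a contract: get_output() may block and should be called at most once. A hypothetical subclass that wraps a concurrent.futures.Future would satisfy the same contract without any CUDA involvement. This is purely illustrative (it assumes a vLLM build containing this commit is importable; FutureBackedOutput is not a real vLLM class):

from concurrent.futures import Future

from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput

class FutureBackedOutput(AsyncModelRunnerOutput):
    """Illustrative subclass: wraps a Future instead of a CUDA copy.

    Real subclasses, such as the GPU one later in this commit, wait on a
    device-to-host transfer instead of a Future.
    """

    def __init__(self, future: "Future[ModelRunnerOutput]"):
        self._future = future

    def get_output(self) -> ModelRunnerOutput:
        # Blocks until the producer has finished, matching the ABC contract
        # that this call may wait for the results to become ready.
        return self._future.result()
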

vllm/v1/worker/gpu_input_batch.py

@@ -250,6 +250,11 @@ class InputBatch:

         self.pooling_params: dict[str, PoolingParams] = {}

+        # Cached reference to the GPU tensor of previously sampled tokens
+        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
+        self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
+        self.prev_req_id_to_index: Optional[dict[str, int]] = None
+
     @property
     def req_ids(self) -> list[str]:
         # None elements should only be present transiently
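
The three cached fields added above let one engine step leave its sampled tokens on the GPU for the next step to pick up without a host sync. A toy sketch with plain tensors (no InputBatch object; the request IDs and token values are invented, and the shapes follow the model-runner code later in this commit):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Step N: the sampler produced one token per request; keep it on the device.
prev_sampled_token_ids = torch.tensor([[11], [22], [33]], device=device)
prev_req_id_to_index = {"req-a": 0, "req-b": 1, "req-c": 2}
prev_invalid_indices: set[int] = set()  # rows whose tokens must be discarded

# Step N+1: look up the previous row for a request that is still running and
# read its last sampled token without ever copying the tensor to the host.
row = prev_req_id_to_index["req-b"]
next_input_token = prev_sampled_token_ids[row, 0]  # stays on the device
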

vllm/v1/worker/gpu_model_runner.py

@@ -67,8 +67,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         MambaSpec, SlidingWindowSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             LogprobsTensors, ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -100,6 +100,53 @@ else:
 logger = init_logger(__name__)


+# Wrapper for ModelRunnerOutput to support overlapped execution.
+class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
+
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        sampled_token_ids: torch.Tensor,
+        invalid_req_indices: list[int],
+        async_output_copy_stream: torch.cuda.Stream,
+    ):
+        self._model_runner_output = model_runner_output
+        self._invalid_req_indices = invalid_req_indices
+
+        # Event on the copy stream so we can synchronize the non-blocking copy.
+        self._async_copy_ready_event = torch.cuda.Event()
+
+        # Keep a reference to the device tensor to avoid it being
+        # deallocated until we finish copying it to the host.
+        self._sampled_token_ids = sampled_token_ids
+
+        # Initiate the copy on a separate stream, but do not synchronize it.
+        default_stream = torch.cuda.current_stream()
+        with torch.cuda.stream(async_output_copy_stream):
+            async_output_copy_stream.wait_stream(default_stream)
+            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
+                'cpu', non_blocking=True)
+            self._async_copy_ready_event.record()
+
+    def get_output(self) -> ModelRunnerOutput:
+        """Copy the device tensors to the host and return a ModelRunnerOutput.
+
+        This function blocks until the copy is finished.
+        """
+        self._async_copy_ready_event.synchronize()
+
+        # Release the device tensor once the copy has completed
+        del self._sampled_token_ids
+
+        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
+        for i in self._invalid_req_indices:
+            valid_sampled_token_ids[i].clear()
+
+        output = self._model_runner_output
+        output.sampled_token_ids = valid_sampled_token_ids
+        return output
+
+
 class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def __init__(
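
The copy-stream/event pattern used by AsyncGPUModelRunnerOutput can be exercised in isolation: enqueue a non-blocking device-to-host copy on a side stream, record an event, and synchronize only when the result is actually needed. A minimal sketch, assuming a CUDA device and an invented tensor of sampled ids:

import torch

assert torch.cuda.is_available()

copy_stream = torch.cuda.Stream()
ready = torch.cuda.Event()

sampled = torch.randint(0, 32000, (8, 1), device="cuda")

default_stream = torch.cuda.current_stream()
with torch.cuda.stream(copy_stream):
    # Make sure the copy waits for the kernels that produced `sampled`.
    copy_stream.wait_stream(default_stream)
    sampled_cpu = sampled.to("cpu", non_blocking=True)
    ready.record()

# ... the default stream is free to launch the next step's kernels here ...

ready.synchronize()          # blocks only when the tokens are actually needed
token_lists = sampled_cpu.tolist()
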

vllm/v1/worker/gpu_model_runner.py (continued)

@@ -230,6 +277,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             is_pooling_model=self.is_pooling_model,
         )

+        self.use_async_scheduling = self.scheduler_config.async_scheduling
+        self.async_output_copy_stream = torch.cuda.Stream() if \
+            self.use_async_scheduling else None
+
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
@@ -654,6 +705,73 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         return cu_num_tokens, arange

+    def _prepare_input_ids(self, total_num_scheduled_tokens: int,
+                           cu_num_tokens: np.ndarray) -> None:
+        """Prepare the input IDs for the current batch.
+
+        Carefully handles the `prev_sampled_token_ids` which can be cached
+        from the previous engine iteration, in which case those tokens on the
+        GPU need to be copied into the corresponding slots into input_ids."""
+
+        if self.input_batch.prev_sampled_token_ids is None:
+            # Normal scheduling case
+            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
+            return
+
+        # Async scheduling case, where some decode requests from the previous
+        # iteration won't have entries in input_ids_cpu and need to be copied
+        # on the GPU from prev_sampled_token_ids.
+        prev_req_id_to_index = self.input_batch.prev_req_id_to_index
+        assert prev_req_id_to_index is not None
+        flattened_indices = []
+        prev_common_req_indices = []
+        indices_match = True
+        max_flattened_index = -1
+        for req_id, cur_index in self.input_batch.req_id_to_index.items():
+            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
+                prev_common_req_indices.append(prev_index)
+                # We need to compute the flattened input_ids index of the
+                # last token in each common request.
+                flattened_index = cu_num_tokens[cur_index].item() - 1
+                flattened_indices.append(flattened_index)
+                indices_match &= (prev_index == flattened_index)
+                max_flattened_index = max(max_flattened_index, flattened_index)
+        num_commmon_tokens = len(flattened_indices)
+        if num_commmon_tokens < total_num_scheduled_tokens:
+            # If not all requests are decodes from the last iteration,
+            # We need to copy the input_ids_cpu to the GPU first.
+            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
+        if num_commmon_tokens == 0:
+            # No requests in common with the previous iteration
+            # So input_ids_cpu will have all the input ids.
+            return
+        if indices_match and max_flattened_index == (num_commmon_tokens - 1):
+            # Common-case optimization: the batch is unchanged
+            # and no reordering happened.
+            # The indices are both the same permutation of 0..N-1 so
+            # we can copy directly using a single slice.
+            self.input_ids.gpu[:num_commmon_tokens].copy_(
+                self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
+                                                        0],
+                non_blocking=True)
+            return
+        # Upload the index tensors asynchronously
+        # so the scatter can be non-blocking.
+        input_ids_index_tensor = torch.tensor(flattened_indices,
+                                              dtype=torch.int64,
+                                              pin_memory=self.pin_memory).to(
+                                                  self.device,
+                                                  non_blocking=True)
+        prev_common_req_indices_tensor = torch.tensor(
+            prev_common_req_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
+        self.input_ids.gpu.scatter_(
+            dim=0,
+            index=input_ids_index_tensor,
+            src=self.input_batch.prev_sampled_token_ids[
+                prev_common_req_indices_tensor, 0])
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
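
The slow path of _prepare_input_ids boils down to a scatter: each persisting request's previously sampled token is written to the flattened position of its last scheduled token. A toy illustration with invented sizes (assumes CUDA; cu_num_tokens is taken here to be the running total of scheduled tokens through each request, so the last token of request i sits at cu_num_tokens[i] - 1, matching the arithmetic in the method above):

import torch

assert torch.cuda.is_available()
device = "cuda"

input_ids_gpu = torch.zeros(6, dtype=torch.int64, device=device)
prev_sampled_token_ids = torch.tensor([[101], [202], [303]], device=device)

# Requests with 2, 1 and 3 scheduled tokens -> running totals [2, 3, 6];
# their last tokens live at flattened positions 1, 2 and 5.
cu_num_tokens = torch.tensor([2, 3, 6])
flattened_indices = (cu_num_tokens - 1).to(device)
prev_common_req_indices = torch.tensor([0, 1, 2], device=device)

input_ids_gpu.scatter_(
    dim=0,
    index=flattened_indices,
    src=prev_sampled_token_ids[prev_common_req_indices, 0])
print(input_ids_gpu)  # tokens 101, 202, 303 land at positions 1, 2, 5
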

vllm/v1/worker/gpu_model_runner.py (continued)

@@ -740,7 +858,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         max_seq_len = self.seq_lens.np[:num_reqs].max().item()

         # Copy the tensors to the GPU.
-        self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
+        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)

         if self.uses_mrope:
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
@@ -1458,7 +1577,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             if not has_kv_transfer_group():
@@ -1673,6 +1792,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 # so that we could clear the sampled tokens before returning.
                 discard_sampled_tokens_req_indices.append(i)

+        # Copy some objects so they don't get modified after returning.
+        # This is important when using async scheduling.
+        req_ids_output_copy = self.input_batch.req_ids.copy()
+        req_id_to_index_output_copy = \
+            self.input_batch.req_id_to_index.copy()
+
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         logprobs_tensors = sampler_output.logprobs_tensors
@@ -1685,21 +1810,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             scheduler_output.num_scheduled_tokens,
         )

-        # Get the valid generated tokens.
+        num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            # No spec decode tokens.
-            valid_sampled_token_ids = self._to_list(sampled_token_ids)
-        else:
-            # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
-            )
-        # Mask out the sampled tokens that should not be sampled.
-        for i in discard_sampled_tokens_req_indices:
-            valid_sampled_token_ids[i].clear()
+        if not self.use_async_scheduling:
+            # Get the valid generated tokens.
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens.
+                valid_sampled_token_ids = self._to_list(sampled_token_ids)
+            else:
+                # Includes spec decode tokens.
+                valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                )
+            # Mask out the sampled tokens that should not be sampled.
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
+        else:
+            valid_sampled_token_ids = []
+            invalid_req_indices = list(discard_sampled_tokens_req_indices)
+            invalid_req_indices_set = set(invalid_req_indices)
+            assert sampled_token_ids.shape[-1] == 1
+
+            # Cache the sampled tokens on the GPU and avoid CPU sync.
+            # These will be copied into input_ids in the next step
+            # when preparing inputs.
+            self.input_batch.prev_sampled_token_ids = \
+                sampled_token_ids
+            self.input_batch.prev_sampled_token_ids_invalid_indices = \
+                invalid_req_indices_set
+            self.input_batch.prev_req_id_to_index = {
+                req_id: i
+                for i, req_id in enumerate(self.input_batch.req_ids)
+                if i not in invalid_req_indices_set
+            }

         # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
@@ -1707,7 +1852,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         req_ids = self.input_batch.req_ids
-        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+        for req_idx in range(num_sampled_tokens):
+            if self.use_async_scheduling:
+                sampled_ids = [-1] if \
+                    req_idx not in invalid_req_indices_set else None
+            else:
+                sampled_ids = valid_sampled_token_ids[req_idx]
             if not sampled_ids:
                 continue

@@ -1722,6 +1872,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                            start_idx:end_idx] = sampled_ids
             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
             self.input_batch.num_tokens[req_idx] = end_idx
+
             req_id = req_ids[req_idx]
             req_state = self.requests[req_id]
             req_state.output_token_ids.extend(sampled_ids)
@@ -1741,9 +1892,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         self.eplb_step()

-        return ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
+        output = ModelRunnerOutput(
+            req_ids=req_ids_output_copy,
+            req_id_to_index=req_id_to_index_output_copy,
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
@@ -1752,6 +1903,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_nans_in_logits=num_nans_in_logits,
         )

+        if not self.use_async_scheduling:
+            return output
+
+        return AsyncGPUModelRunnerOutput(
+            model_runner_output=output,
+            sampled_token_ids=sampled_token_ids,
+            invalid_req_indices=invalid_req_indices,
+            async_output_copy_stream=self.async_output_copy_stream,
+        )
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
             return None
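
Returning a handle instead of finished token lists is what allows the overlap: the engine can launch step N+1 while step N's device-to-host copy is still in flight, and only then resolve step N's tokens. A schematic, toy timeline (no vLLM or CUDA objects; execute_model_step and the thread pool are invented stand-ins for the model runner and the copy stream):

import time
from concurrent.futures import Future, ThreadPoolExecutor

copy_pool = ThreadPoolExecutor(max_workers=1)

def execute_model_step(step: int) -> Future:
    # Pretend the forward pass has been launched; the "device-to-host copy"
    # of the sampled tokens finishes a bit later on the copy worker.
    def finish_copy():
        time.sleep(0.05)
        return [f"token-for-step-{step}"]
    return copy_pool.submit(finish_copy)

pending = execute_model_step(0)
for step in range(1, 4):
    next_pending = execute_model_step(step)   # next step starts immediately
    print(step - 1, pending.result())         # resolve previous step's output
    pending = next_pending
print(3, pending.result())
copy_pool.shutdown()
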

vllm/v1/worker/gpu_worker.py

@@ -5,7 +5,7 @@ import copy
 import gc
 import os
 from contextlib import AbstractContextManager, nullcontext
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
 import torch.distributed
@@ -28,8 +28,8 @@ from vllm.tasks import SupportedTask
 from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, ModelRunnerOutput)
 from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 from vllm.v1.worker.worker_base import WorkerBase
@@ -355,7 +355,7 @@ class Worker(WorkerBase):
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> Optional[ModelRunnerOutput]:
+    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
         intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
         if forward_pass and not get_pp_group().is_first_rank:
@@ -365,7 +365,7 @@ class Worker(WorkerBase):

         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
-        if isinstance(output, ModelRunnerOutput):
+        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
             return output

         assert isinstance(output, IntermediateTensors)