[Core] Eliminate parallel worker per-step task scheduling overhead (#4894)

Nick Hill 2024-05-22 14:17:27 -07:00 committed by GitHub
parent 97b030005c
commit eb6d3c264d
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
12 changed files with 350 additions and 211 deletions

View File

@@ -234,6 +234,14 @@ class _AsyncLLMEngine(LLMEngine):
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
 
+        if not request_outputs:
+            # Stop the execute model loop in parallel workers until there are
+            # more requests to process. This avoids waiting indefinitely in
+            # torch.distributed ops which may otherwise timeout, and unblocks
+            # the RPC thread in the workers so that they can process any other
+            # queued control plane messages, such as add/remove lora adapters.
+            await self.model_executor.stop_remote_worker_execution_loop_async()
+
         return request_outputs
 
     async def encode_request_async(

View File

@@ -692,6 +692,14 @@ class LLMEngine:
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
 
+        if not request_outputs:
+            # Stop the execute model loop in parallel workers until there are
+            # more requests to process. This avoids waiting indefinitely in
+            # torch.distributed ops which may otherwise timeout, and unblocks
+            # the RPC thread in the workers so that they can process any other
+            # queued control plane messages, such as add/remove lora adapters.
+            self.model_executor.stop_remote_worker_execution_loop()
+
         return request_outputs
 
     def do_log_stats(
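
A minimal sketch of the idle-stop pattern introduced by the two engine hunks above. The classes and names here (ToyEngine, ToyExecutor) are hypothetical stand-ins, not vLLM's actual implementation; they only show the control flow: when a step produces no outputs, the engine releases the parallel workers from their collective execute loop, and the next batch restarts it.

    # Hypothetical stand-ins for LLMEngine / DistributedGPUExecutor.
    class ToyExecutor:
        def __init__(self):
            self.loop_running = False

        def execute_model(self, batch):
            if not self.loop_running:
                self.loop_running = True   # workers enter their collective loop here
            return [f"output for {item}" for item in batch]

        def stop_remote_worker_execution_loop(self):
            self.loop_running = False      # workers fall out of the loop and go idle


    class ToyEngine:
        def __init__(self):
            self.executor = ToyExecutor()
            self.waiting = [["req-1", "req-2"], ["req-3"], []]  # scheduled batches

        def step(self):
            batch = self.waiting.pop(0) if self.waiting else []
            outputs = self.executor.execute_model(batch) if batch else []
            if not outputs:
                # Mirror of the diff above: release workers when there is no work,
                # so they can serve control-plane RPCs instead of blocking in
                # collective ops.
                self.executor.stop_remote_worker_execution_loop()
            return outputs


    engine = ToyEngine()
    while True:
        out = engine.step()
        print(out)
        if not out:
            break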

View File

@@ -1,11 +1,12 @@
+import asyncio
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
 
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 
 logger = init_logger(__name__)
@@ -13,6 +14,16 @@ logger = init_logger(__name__)
 class DistributedGPUExecutor(GPUExecutor):
     """Abstract superclass of multi-GPU executor implementations."""
 
+    def __init__(self, *args, **kwargs):
+        # This is non-None when the execute model loop is running
+        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
+        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
+        # Updated by implementations that require additional args to be passed
+        # to the _run_workers execute_model call
+        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
+
+        super().__init__(*args, **kwargs)
+
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.
@@ -52,13 +63,28 @@ class DistributedGPUExecutor(GPUExecutor):
                                       num_gpu_blocks=num_gpu_blocks,
                                       num_cpu_blocks=num_cpu_blocks)
 
-    def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
-        all_outputs = self._run_workers("execute_model",
-                                        driver_args=args,
-                                        driver_kwargs=kwargs)
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        if self.parallel_worker_tasks is None:
+            self.parallel_worker_tasks = self._run_workers(
+                "start_worker_execution_loop",
+                async_run_remote_workers_only=True,
+                **self.extra_execute_model_run_workers_kwargs)
 
         # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+        return self._driver_execute_model(execute_model_req)
+
+    def stop_remote_worker_execution_loop(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+
+        self._driver_execute_model()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit model loop cleanly
+        # (this will raise otherwise)
+        self._wait_for_tasks_completion(parallel_worker_tasks)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
@@ -88,39 +114,84 @@ class DistributedGPUExecutor(GPUExecutor):
                           pattern=pattern,
                           max_size=max_size)
 
+    @abstractmethod
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers.
+
+        Args:
+            async_run_remote_workers_only: If True the method will be run only
+                in the remote workers, not the driver worker. It will also be
+                run asynchronously and return a list of futures rather than
+                blocking on the results.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
         raise NotImplementedError
 
 
 class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
 
-    @abstractmethod
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        raise NotImplementedError
-
-    async def execute_model_async(self, *args,
-                                  **kwargs) -> List[SamplerOutput]:
-        all_outputs = await self._run_workers_async("execute_model",
-                                                    driver_args=args,
-                                                    driver_kwargs=kwargs)
+    async def execute_model_async(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        if self.parallel_worker_tasks is None:
+            # Start model execution loop running in the parallel workers
+            self.parallel_worker_tasks = asyncio.create_task(
+                self._start_worker_execution_loop())
 
         # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+        return await self._driver_execute_model_async(execute_model_req)
+
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+
+        await self._driver_execute_model_async()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit model loop cleanly
+        # (this will raise otherwise)
+        await parallel_worker_tasks
+
+    @abstractmethod
+    async def _driver_execute_model_async(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Execute the model asynchronously in the driver worker.
+
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def _start_worker_execution_loop(self):
+        """Run execution loop on all workers. It guarantees all workers run
+        the loop or None of them is running the loop. Loop can be stopped by
+        `stop_remote_worker_execution_loop`.
+
+        The API is idempotent (guarantee only 1 loop run at any moment)."""
+        raise NotImplementedError
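
A toy illustration of the parallel_worker_tasks state machine the abstract executor above relies on, with a thread pool and queues standing in for remote workers (hypothetical names; not vLLM code): the first execute call launches the worker loops once and keeps the returned futures, and stopping drains the loops and then waits on those futures so any worker failure surfaces.

    import queue
    from concurrent.futures import ThreadPoolExecutor

    STOP = object()  # sentinel playing the role of the "empty broadcast"

    class ToyDistributedExecutor:
        def __init__(self, num_workers: int = 2):
            self.pool = ThreadPoolExecutor(max_workers=num_workers)
            self.queues = [queue.Queue() for _ in range(num_workers)]
            self.parallel_worker_tasks = None  # futures while the loop is running

        def _worker_loop(self, q: queue.Queue) -> int:
            steps = 0
            while q.get() is not STOP:   # block until the driver publishes a step
                steps += 1
            return steps

        def execute_model(self, request):
            if self.parallel_worker_tasks is None:
                # First call after idle: start the loop once in every worker.
                self.parallel_worker_tasks = [
                    self.pool.submit(self._worker_loop, q) for q in self.queues
                ]
            for q in self.queues:        # driver "broadcasts" the step
                q.put(request)
            return f"driver output for {request}"

        def stop_remote_worker_execution_loop(self):
            if self.parallel_worker_tasks is None:
                return
            for q in self.queues:        # empty input tells workers to exit the loop
                q.put(STOP)
            tasks, self.parallel_worker_tasks = self.parallel_worker_tasks, None
            for t in tasks:
                t.result()               # re-raises if a worker loop crashed

    ex = ToyDistributedExecutor()
    print(ex.execute_model("step-1"))
    print(ex.execute_model("step-2"))
    ex.stop_remote_worker_execution_loop()
    ex.pool.shutdown()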

View File

@@ -74,6 +74,10 @@ class ExecutorBase(ABC):
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError
 
+    def stop_remote_worker_execution_loop(self) -> None:
+        """Releases parallel workers from model loop."""
+        return
+
     @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError
@@ -109,6 +113,10 @@ class ExecutorAsyncBase(ExecutorBase):
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        """Releases parallel workers from model loop."""
+        return
+
     async def check_health_async(self) -> None:
         """Checks if the executor is healthy. If not, it should raise an
         exception."""

View File

@@ -1,13 +1,14 @@
 import asyncio
 import os
 from functools import partial
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, List, Optional
 
 from vllm.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                   ResultHandler, WorkerMonitor)
 from vllm.logger import init_logger
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         get_vllm_instance_id, make_async)
@@ -71,16 +72,34 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
                                None)) is not None:
             worker_monitor.close()
 
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        return self.driver_worker.execute_model(
+            execute_model_req=execute_model_req)
+
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers.
+
+        Args:
+            async_run_remote_workers_only: If True the method will be run only
+                in the remote workers, not the driver worker. It will also be
+                run asynchronously and return a list of futures rather than
+                blocking on the results.
+        """
 
         if max_concurrent_workers:
             raise NotImplementedError(
@@ -92,15 +111,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
             for worker in self.workers
         ]
 
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
+        if async_run_remote_workers_only:
+            # Just return futures
+            return worker_outputs
 
-        # Start the driver worker after all the ray workers.
         driver_worker_method = getattr(self.driver_worker, method)
-        driver_worker_output = driver_worker_method(*driver_args,
-                                                    **driver_kwargs)
+        driver_worker_output = driver_worker_method(*args, **kwargs)
 
         # Get the results of the workers.
         return [driver_worker_output
@@ -111,30 +127,29 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         if not self.worker_monitor.is_alive():
             raise RuntimeError("Worker processes are not running")
 
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        for result in parallel_worker_tasks:
+            result.get()
+
 
 class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
                                       DistributedGPUExecutorAsync):
 
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        driver_executor = make_async(getattr(self.driver_worker, method))
-
-        # Run all the workers asynchronously.
-        coros = [driver_executor(*driver_args, **driver_kwargs)] + [
-            worker.execute_method_async(method, *args, **kwargs)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_exec_model = make_async(self.driver_worker.execute_model)
+
+    async def _driver_execute_model_async(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        return await self.driver_exec_model(execute_model_req)
+
+    async def _start_worker_execution_loop(self):
+        coros = [
+            worker.execute_method_async("start_worker_execution_loop")
             for worker in self.workers
         ]
 
         return await asyncio.gather(*coros)
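
The async executor above wraps the blocking driver call with vLLM's make_async helper. A rough equivalent, sketched with standard asyncio rather than vLLM's utility (an assumption about the helper's behavior, not its actual code), is to push the blocking call onto a thread so the event loop stays responsive:

    import asyncio
    import functools
    import time
    from typing import Callable

    def make_async_sketch(fn: Callable) -> Callable:
        """Return a coroutine function that runs `fn` in the default thread pool."""
        async def _wrapper(*args, **kwargs):
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None, functools.partial(fn, *args, **kwargs))
        return _wrapper

    def blocking_driver_execute(step: int) -> str:
        time.sleep(0.1)           # stands in for a synchronous GPU model step
        return f"driver finished step {step}"

    async def main():
        driver_exec_model = make_async_sketch(blocking_driver_execute)
        # The event loop is free to run other coroutines while the steps execute.
        results = await asyncio.gather(driver_exec_model(1), driver_exec_model(2))
        print(results)

    asyncio.run(main())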

View File

@@ -42,6 +42,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
         self.forward_dag = None
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
+            self.extra_execute_model_run_workers_kwargs[
+                "use_ray_compiled_dag"] = True
 
     def _configure_ray_workers_use_nsight(self,
                                           ray_remote_kwargs) -> Dict[str, Any]:
@@ -171,23 +173,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
                           max_concurrent_workers=self.parallel_config.
                           max_parallel_loading_workers)
 
-    def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
-        all_outputs = self._run_workers(
-            "execute_model",
-            driver_kwargs={"execute_model_req": execute_model_req},
-            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
-
-        # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        return self.driver_worker.execute_method("execute_model",
+                                                 execute_model_req)
 
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
         use_dummy_driver: bool = False,
@@ -198,9 +200,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
         """Runs the given method on all workers. Can be used in the following
         ways:
 
+        - async_run_remote_workers_only: If True the method will be run only
+          in the remote workers, not the driver worker. It will also be
+          run asynchronously and return a list of futures rather than blocking
+          on the results.
         - args/kwargs: All workers share the same args/kwargs
-        - args/kwargs and driver_args/driver_kwargs: Driver worker has
-          different args
         - all_args/all_kwargs: args/kwargs for each worker are specified
           individually
         """
@@ -209,11 +213,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        if driver_args is None:
-            driver_args = args if all_args is None else all_args[0]
-        if driver_kwargs is None:
-            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
-
         count = len(self.workers)
         all_worker_args = repeat(args, count) if all_args is None \
             else islice(all_args, 1, None)
@@ -225,6 +224,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
             # input. TODO(sang): Fix it.
             assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
+            ray_worker_outputs = []
         else:
             # Start the ray workers first.
             ray_worker_outputs = [
@@ -234,6 +234,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
             ]
 
+        if async_run_remote_workers_only:
+            # Just return futures
+            return ray_worker_outputs
+
+        driver_args = args if all_args is None else all_args[0]
+        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
         # Start the driver worker after all the ray workers.
         if not use_dummy_driver:
             driver_worker_output = self.driver_worker.execute_method(
@@ -260,6 +267,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
 
         return [driver_worker_output] + ray_worker_outputs
 
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        ray.get(parallel_worker_tasks)
+
     def _compiled_ray_dag(self):
         import pkg_resources
         required_version = "2.9"
@@ -303,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.driver_executor = make_async(self.driver_worker.execute_method)
+        self.driver_exec_method = make_async(self.driver_worker.execute_method)
 
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        coros.append(
-            self.driver_executor(method, *driver_args, **driver_kwargs))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
+    async def _driver_execute_model_async(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        return await self.driver_exec_method("execute_model",
+                                             execute_model_req)
+
+    async def _start_worker_execution_loop(self):
+        coros = [
+            worker.execute_method.remote("start_worker_execution_loop")
+            for worker in self.workers
+        ]
+        return await asyncio.gather(*coros)
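
For the Ray path, `_run_workers(..., async_run_remote_workers_only=True)` returns Ray ObjectRefs and `_wait_for_tasks_completion` blocks on them with ray.get. A small self-contained illustration of that fire-then-join pattern (requires Ray to be installed; the actor below is a stand-in, not vLLM's worker class):

    import ray

    ray.init(num_cpus=2)

    @ray.remote
    class LoopWorker:
        def start_worker_execution_loop(self, steps: int) -> int:
            # Stand-in for the real worker loop, which would block on broadcasts.
            done = 0
            for _ in range(steps):
                done += 1
            return done

    workers = [LoopWorker.remote() for _ in range(2)]

    # "async_run_remote_workers_only": kick off the method on every remote worker
    # and keep the ObjectRefs instead of blocking on results.
    parallel_worker_tasks = [
        w.start_worker_execution_loop.remote(1000) for w in workers
    ]

    # ... the driver would keep executing model steps here ...

    # "_wait_for_tasks_completion": join the loops; raises if any worker failed.
    print(ray.get(parallel_worker_tasks))
    ray.shutdown()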

View File

@@ -47,7 +47,9 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         # NGram don't need gpu sampler
         pass
 
-    def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
+    def execute_model(
+            self,
+            execute_model_req: Optional[ExecuteModelRequest] = None) -> None:
         """NGram doesn't depend on model execution, just pass this function"""
         pass

View File

@@ -231,35 +231,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                               num_cpu_blocks=num_cpu_blocks)
 
-    def _broadcast_control_flow_decision(
-            self,
-            execute_model_req: Optional[ExecuteModelRequest] = None,
-            disable_all_speculation: bool = False) -> Tuple[int, bool]:
-        """Broadcast how many lookahead slots are scheduled for this step, and
-        whether all speculation is disabled, to all non-driver workers.
-
-        This is required as if the number of draft model runs changes
-        dynamically, the non-driver workers won't know unless we perform a
-        communication to inform then.
-
-        Returns the broadcasted num_lookahead_slots and disable_all_speculation.
-        """
-        if self.rank == self._driver_rank:
-            assert execute_model_req is not None
-
-            broadcast_dict = dict(
-                num_lookahead_slots=execute_model_req.num_lookahead_slots,
-                disable_all_speculation=disable_all_speculation,
-            )
-            broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
-        else:
-            assert execute_model_req is None
-            broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)
-
-        return (broadcast_dict["num_lookahead_slots"],
-                broadcast_dict["disable_all_speculation"])
-
     @torch.inference_mode()
     def execute_model(
         self,
@@ -267,39 +238,58 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     ) -> List[SamplerOutput]:
         """Perform speculative decoding on the input batch.
         """
 
-        disable_all_speculation = False
-        if self.rank == self._driver_rank:
-            disable_all_speculation = self._should_disable_all_speculation(
-                execute_model_req)
-
-        (num_lookahead_slots,
-         disable_all_speculation) = self._broadcast_control_flow_decision(
-             execute_model_req, disable_all_speculation)
-
-        if self.rank == self._driver_rank:
-            assert execute_model_req is not None
-            assert execute_model_req.seq_group_metadata_list is not None, (
-                "speculative decoding requires non-None seq_group_metadata_list"
-            )
-
-            self._maybe_disable_speculative_tokens(
-                disable_all_speculation,
-                execute_model_req.seq_group_metadata_list)
-
-            # If no spec tokens, call the proposer and scorer workers normally.
-            # Used for prefill.
-            if num_lookahead_slots == 0 or len(
-                    execute_model_req.seq_group_metadata_list) == 0:
-                return self._run_no_spec(execute_model_req,
-                                         skip_proposer=disable_all_speculation)
-
-            return self._run_speculative_decoding_step(execute_model_req,
-                                                       num_lookahead_slots)
-        else:
-            self._run_non_driver_rank(num_lookahead_slots)
-            return []
+        if self.rank != self._driver_rank:
+            self._run_non_driver_rank()
+            return []
+
+        if execute_model_req is None:
+            # This signals that there's no more requests to process for now.
+            # All workers are running infinite loop with broadcast_tensor_dict,
+            # and it stops the loop when the driver broadcasts an empty input.
+            # Send an empty input to notify all other workers to stop their
+            # execution loop.
+            broadcast_tensor_dict({}, src=0)
+            return []
+
+        disable_all_speculation = self._should_disable_all_speculation(
+            execute_model_req)
+        num_lookahead_slots = execute_model_req.num_lookahead_slots
+
+        # Broadcast how many lookahead slots are scheduled for this step, and
+        # whether all speculation is disabled, to all non-driver workers.
+        # This is required as if the number of draft model runs changes
+        # dynamically, the non-driver workers won't know unless we perform a
+        # communication to inform then.
+        broadcast_dict = dict(
+            num_lookahead_slots=num_lookahead_slots,
+            disable_all_speculation=disable_all_speculation,
+        )
+        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
+
+        assert execute_model_req.seq_group_metadata_list is not None, (
+            "speculative decoding requires non-None seq_group_metadata_list")
+
+        self._maybe_disable_speculative_tokens(
+            disable_all_speculation, execute_model_req.seq_group_metadata_list)
+
+        # If no spec tokens, call the proposer and scorer workers normally.
+        # Used for prefill.
+        if num_lookahead_slots == 0 or len(
+                execute_model_req.seq_group_metadata_list) == 0:
+            return self._run_no_spec(execute_model_req,
+                                     skip_proposer=disable_all_speculation)
+
+        return self._run_speculative_decoding_step(execute_model_req,
+                                                   num_lookahead_slots)
+
+    @torch.inference_mode()
+    def start_worker_execution_loop(self) -> None:
+        """Execute model loop to perform speculative decoding
+        in parallel worker."""
+        while self._run_non_driver_rank():
+            pass
 
     def _should_disable_all_speculation(
             self, execute_model_req: ExecuteModelRequest) -> bool:
         # When the batch size is too large, disable speculative decoding
@@ -346,13 +336,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             sampler_output.logprobs = None
         return [sampler_output]
 
-    def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
+    def _run_non_driver_rank(self) -> bool:
         """Run proposer and verifier model in non-driver workers. This is used
         for both speculation cases (num_lookahead_slots>0) and non-speculation
         cases (e.g. prefill).
+
+        Returns True iff there are remaining sequences to process.
         """
-        # In non-driver workers the input is None
-        execute_model_req = None
+        assert self.rank != self._driver_rank
+
+        data = broadcast_tensor_dict(src=self._driver_rank)
+        if not data:
+            return False
+        num_lookahead_slots = data["num_lookahead_slots"]
 
         # Even if num_lookahead_slots is zero, we want to run the proposer model
         # as it may have KV.
@@ -360,9 +356,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # We run the proposer once per lookahead slot. In the future we should
         # delegate how many times it runs to the proposer.
         for _ in range(max(num_lookahead_slots, 1)):
-            self.proposer_worker.execute_model(execute_model_req)
-            self.scorer_worker.execute_model(execute_model_req)
+            self.proposer_worker.execute_model()
+            self.scorer_worker.execute_model()
+
+        return True
 
     @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
     def _run_speculative_decoding_step(

View File

@@ -47,7 +47,7 @@ class EmbeddingModelRunner(ModelRunner):
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[torch.Tensor],
     ) -> Optional[PoolerOutput]:
         (input_tokens, input_positions, attn_metadata, pooling_metadata,
@@ -84,10 +84,11 @@ class EmbeddingModelRunner(ModelRunner):
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
                Set[LoRARequest], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
             # Prepare input tensors.
             (
                 input_tokens,

View File

@@ -609,10 +609,11 @@ class ModelRunner:
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[LoRARequest], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
             # Prepare input tensors.
             (
                 input_tokens,
@@ -676,7 +677,7 @@ class ModelRunner:
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,

View File

@@ -226,48 +226,42 @@ class Worker(WorkerBase):
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[Union[SamplerOutput, PoolerOutput]]:
+        if not self.is_driver_worker:
+            self._execute_model_non_driver()
+            return []
 
         if execute_model_req is None:
-            seq_group_metadata_list = None
-        else:
-            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+            # This signals that there's no more requests to process for now.
+            # All workers are running infinite loop with broadcast_tensor_dict,
+            # and it stops the loop when the driver broadcasts an empty input.
+            # Send an empty input to notify all other workers to stop their
+            # execution loop.
+            broadcast_tensor_dict({}, src=0)
+            return []
 
-        blocks_to_swap_in: torch.Tensor
-        blocks_to_swap_out: torch.Tensor
-        blocks_to_copy: torch.Tensor
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            assert execute_model_req is not None
-            num_seq_groups = len(seq_group_metadata_list)
-            # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
-            # they contain parameters to launch cudamemcpyasync.
-            blocks_to_swap_in = torch.tensor(
-                execute_model_req.blocks_to_swap_in,
-                device="cpu",
-                dtype=torch.int64).view(-1, 2)
-            blocks_to_swap_out = torch.tensor(
-                execute_model_req.blocks_to_swap_out,
-                device="cpu",
-                dtype=torch.int64).view(-1, 2)
-            # `blocks_to_copy` is a gpu tensor. The src and tgt of
-            # blocks to copy are in the same device, and `blocks_to_copy`
-            # can be used directly within cuda kernels.
-            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
-                                          device=self.device,
-                                          dtype=torch.int64).view(-1, 2)
-            data: Dict[str, Any] = {
-                "num_seq_groups": num_seq_groups,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-            }
-            broadcast_tensor_dict(data, src=0)
-        else:
-            data = broadcast_tensor_dict(src=0)
-            num_seq_groups = data["num_seq_groups"]
-            blocks_to_swap_in = data["blocks_to_swap_in"]
-            blocks_to_swap_out = data["blocks_to_swap_out"]
-            blocks_to_copy = data["blocks_to_copy"]
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+        num_seq_groups = len(seq_group_metadata_list)
+        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
+        # they contain parameters to launch cudamemcpyasync.
+        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
+                                         device="cpu",
+                                         dtype=torch.int64).view(-1, 2)
+        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
+                                          device="cpu",
+                                          dtype=torch.int64).view(-1, 2)
+        # `blocks_to_copy` is a gpu tensor. The src and tgt of
+        # blocks to copy are in the same device, and `blocks_to_copy`
+        # can be used directly within cuda kernels.
+        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
+                                      device=self.device,
+                                      dtype=torch.int64).view(-1, 2)
+        data: Dict[str, Any] = {
+            "num_seq_groups": num_seq_groups,
+            "blocks_to_swap_in": blocks_to_swap_in,
+            "blocks_to_swap_out": blocks_to_swap_out,
+            "blocks_to_copy": blocks_to_copy,
+        }
+        broadcast_tensor_dict(data, src=0)
 
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
@@ -282,6 +276,39 @@ class Worker(WorkerBase):
         # to conform to interface.
         return [output]
 
+    @torch.inference_mode()
+    def start_worker_execution_loop(self) -> None:
+        """Execute model loop in parallel worker.
+
+        You can stop the loop by executing a driver worker with an empty output.
+        See `stop_remote_worker_execution_loop` for more details.
+        """
+        while self._execute_model_non_driver():
+            pass
+
+    def _execute_model_non_driver(self) -> bool:
+        """Execute model in parallel worker.
+
+        Returns True iff there are remaining sequences to process.
+        """
+        assert not self.is_driver_worker
+        data = broadcast_tensor_dict(src=0)
+        if not data:
+            return False
+
+        num_seq_groups = data.get("num_seq_groups", 0)
+        blocks_to_swap_in = data.get("blocks_to_swap_in")
+        blocks_to_swap_out = data.get("blocks_to_swap_out")
+        blocks_to_copy = data.get("blocks_to_copy")
+
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
+
+        # If there is no input, we don't need to execute the model.
+        if num_seq_groups == 0:
+            return False
+
+        self.model_runner.execute_model(None, self.gpu_cache)
+        return True
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
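
The stop signal itself is just an empty broadcast. vLLM's broadcast_tensor_dict is its own utility; the convention can be illustrated with plain torch.distributed.broadcast_object_list on the gloo backend (a runnable sketch under that assumption, not vLLM code): non-driver ranks loop on the broadcast and exit when the driver sends an empty dict.

    import torch.distributed as dist
    import torch.multiprocessing as mp

    def run(rank: int, world_size: int) -> None:
        dist.init_process_group("gloo",
                                init_method="tcp://127.0.0.1:29555",
                                rank=rank,
                                world_size=world_size)
        if rank == 0:
            # Driver: broadcast per-step metadata, then an empty dict to stop.
            for payload in ({"num_seq_groups": 3}, {"num_seq_groups": 1}, {}):
                dist.broadcast_object_list([payload], src=0)
        else:
            # Parallel worker: the execution loop of the diff above, in miniature.
            while True:
                box = [None]
                dist.broadcast_object_list(box, src=0)
                data = box[0]
                if not data:
                    break  # empty input == stop_remote_worker_execution_loop
                print(f"rank {rank}: executing step with "
                      f"{data['num_seq_groups']} sequence groups")
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)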

View File

@@ -1,7 +1,7 @@
 import importlib
 import os
 from abc import ABC, abstractmethod
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -48,8 +48,9 @@ class WorkerBase(ABC):
     @abstractmethod
     def execute_model(
         self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
         """Executes at least one model step on the given sequences, unless no
         sequences are provided."""
         raise NotImplementedError