Prune Ray v1 non-SPMD code paths

Simon Mo 2025-09-18 20:42:46 -07:00
parent 07665f8679
commit 85013bf094


@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import asyncio
 import os
 import threading
 from collections import defaultdict
@@ -22,8 +21,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, make_async)
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        make_async)
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.executor.abstract import Executor
@@ -155,11 +154,6 @@ class RayDistributedExecutor(DistributedExecutorBase, Executor):
         self.input_encoder = None
         self.output_decoder = None
-        self.pp_locks: Optional[List[asyncio.Lock]] = None
-        if not self.use_ray_compiled_dag:
-            self.driver_exec_method = make_async(
-                self.driver_worker.execute_method)
-
         # KV connector setup
         self.has_connector = self.vllm_config.kv_transfer_config is not None
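The branch deleted here wrapped the driver's blocking execute_method with make_async so the async engine could await it. For context, a minimal sketch of what vllm.utils.make_async does (a thread-pool wrapper, not the exact vLLM source):

import asyncio
import functools
from typing import Any, Callable

def make_async(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a blocking callable so it can be awaited; the call runs in
    the default executor thread pool (sketch of vllm.utils.make_async)."""
    async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_running_loop()
        call = functools.partial(func, *args, **kwargs)
        return await loop.run_in_executor(None, call)
    return _async_wrapper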
@@ -177,11 +171,6 @@ class RayDistributedExecutor(DistributedExecutorBase, Executor):
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         """Execute the model on the Ray workers."""
-        if not self.use_ray_spmd_worker:
-            raise RuntimeError(
-                "RayDistributedExecutor in v1 requires "
-                "VLLM_USE_RAY_SPMD_WORKER=1")
-
         # Build the compiled DAG for the first time.
         if self.forward_dag is None:
             self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
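With the guard gone, execute_model has a single code path: compile the Ray DAG lazily on first use, then reuse it on every step. A minimal sketch of that build-once, execute-many shape (ForwardDag and ExecutorSketch are illustrative stand-ins, not vLLM or Ray APIs):

from typing import Any, Optional

class ForwardDag:  # illustrative stand-in for the compiled Ray DAG
    def execute(self, inp: Any) -> Any:
        return inp  # placeholder: the real DAG fans out to all SPMD workers

class ExecutorSketch:
    def __init__(self) -> None:
        self.forward_dag: Optional[ForwardDag] = None

    def execute_model(self, scheduler_output: Any) -> Any:
        # Compile once on the first call, then reuse the same DAG.
        if self.forward_dag is None:
            self.forward_dag = ForwardDag()
        return self.forward_dag.execute(scheduler_output)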
@@ -247,10 +236,7 @@ class RayDistributedExecutor(DistributedExecutorBase, Executor):
                            **ray_remote_kwargs):
         num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
 
-        # The driver dummy worker does not actually use any resources.
-        # It holds the resource for the driver worker.
-        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
-        # The remaining workers are the actual ray actors.
+        # Ray actors that perform all model execution.
         self.workers: List[RayWorkerWrapper] = []
 
         # Used in ray compiled DAG: indexed first by PP rank,
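The workers created here are plain Ray actors, and VLLM_RAY_PER_WORKER_GPUS can be fractional, which Ray supports natively through actor options. A standalone sketch (WorkerStub is an illustrative stand-in for RayWorkerWrapper):

import ray

@ray.remote
class WorkerStub:  # illustrative stand-in for RayWorkerWrapper
    def get_node_and_gpu_ids(self):
        # Report which node this actor landed on and which GPUs Ray assigned it.
        return ray.get_runtime_context().get_node_id(), ray.get_gpu_ids()

# Two workers sharing a single GPU, mirroring num_gpus < 1:
# workers = [WorkerStub.options(num_gpus=0.5).remote() for _ in range(2)]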
@@ -323,28 +309,7 @@ class RayDistributedExecutor(DistributedExecutorBase, Executor):
         for each, ip in zip(worker_metadata, worker_ips):
             each.ip = ip
 
-        if not self.use_ray_spmd_worker:
-            for i, each in enumerate(worker_metadata):
-                # find and remove the dummy worker from the list
-                worker = each.worker
-                worker_ip = each.ip
-                if self.driver_dummy_worker is None and worker_ip == driver_ip:
-                    # If the worker is on the same node as the driver, we use it
-                    # as the resource holder for the driver process.
-                    self.driver_dummy_worker = worker
-                    self.driver_worker = RayWorkerWrapper(
-                        vllm_config=self.vllm_config, rpc_rank=0)
-                    worker_metadata.pop(i)
-                    break
-
-        logger.debug("workers: %s", worker_metadata)
-        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
-        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
-            raise ValueError(
-                "Ray does not allocate any GPUs on the driver node."
-                f"Driver IP: {driver_ip}, worker IPs: {worker_ips}."
-                "Consider adjusting the Ray placement group or running "
-                "the driver on a GPU node.")
 
         ip_counts: Dict[str, int] = {}
         for ip in worker_ips:
@@ -369,7 +334,7 @@ class RayDistributedExecutor(DistributedExecutorBase, Executor):
         # node will be placed first.
         sorted_worker_metadata = sorted(worker_metadata,
                                         key=sort_by_driver_then_worker_ip)
-        start_rank = 0 if self.use_ray_spmd_worker else 1
+        start_rank = 0
         for i, item in enumerate(sorted_worker_metadata):
             item.adjusted_rank = i + start_rank
         self.workers = [item.worker for item in sorted_worker_metadata]
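sort_by_driver_then_worker_ip is defined earlier in this file; judging from the surrounding code, it orders workers co-located with the driver first and keeps each node's workers contiguous. A plausible reconstruction, not the verbatim source:

from typing import Dict

def make_sort_key(driver_ip: str, ip_counts: Dict[str, int]):
    """Plausible reconstruction of the sort key: driver-node workers
    first, then workers grouped by their node's worker count, with the
    IP itself as the final tie-breaker so each node stays contiguous."""
    def sort_by_driver_then_worker_ip(item):
        ip = item.ip
        return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
    return sort_by_driver_then_worker_ip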
@@ -381,10 +346,7 @@ class RayDistributedExecutor(DistributedExecutorBase, Executor):
         # Get the set of GPU IDs used on each node.
         worker_node_and_gpu_ids = []
-        for worker in [self.driver_dummy_worker] + self.workers:
-            if worker is None:
-                # driver_dummy_worker can be None when using ray spmd worker.
-                continue
+        for worker in self.workers:
             worker_node_and_gpu_ids.append(
                 ray.get(worker.get_node_and_gpu_ids.remote())
             )  # type: ignore
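A side note on the retained loop: it calls ray.get once per worker, so the round trips run back to back. Launching every remote call before a single ray.get lets them overlap; a self-contained sketch of that variant:

import ray
from typing import Any, List

def gather_node_and_gpu_ids(workers: List[Any]) -> List[Any]:
    # Launch all remote calls first, then block once on the whole batch,
    # so per-call network latency overlaps instead of accumulating.
    refs = [w.get_node_and_gpu_ids.remote() for w in workers]
    return ray.get(refs)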
@@ -476,46 +438,23 @@ class RayDistributedExecutor(DistributedExecutorBase, Executor):
                                   max_concurrent_workers=self.parallel_config.
                                   max_parallel_loading_workers)
 
-        if self.use_ray_spmd_worker:
-            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
-                self.pp_tp_workers.append([])
-                for tp_rank in range(
-                        self.parallel_config.tensor_parallel_size):
-                    # PP=2, TP=4
-                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
-                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
-                            ) + tp_rank
-                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
-                    assert pp_rank < len(self.pp_tp_workers)
-                    self.pp_tp_workers[pp_rank].append(self.workers[rank])
-
-        # This is the list of workers that are rank 0 of each TP group EXCEPT
-        # global rank 0. These are the workers that will broadcast to the
-        # rest of the workers.
-        self.tp_driver_workers: List[RayWorkerWrapper] = []
-        # This is the list of workers that are not drivers and not the first
-        # worker in a TP group. These are the workers that will be
-        # broadcasted to.
-        self.non_driver_workers: List[RayWorkerWrapper] = []
-
-        # Enforce rank order for correct rank to return final output.
-        for index, worker in enumerate(self.workers):
-            # The driver worker is rank 0 and not in self.workers.
-            rank = index + 1
-            if rank % self.parallel_config.tensor_parallel_size == 0:
-                self.tp_driver_workers.append(worker)
-            else:
-                self.non_driver_workers.append(worker)
+        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
+            self.pp_tp_workers.append([])
+            for tp_rank in range(self.parallel_config.tensor_parallel_size):
+                # PP=2, TP=4
+                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
+                rank = (pp_rank * self.parallel_config.tensor_parallel_size
+                        ) + tp_rank
+                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
+                assert pp_rank < len(self.pp_tp_workers)
+                self.pp_tp_workers[pp_rank].append(self.workers[rank])
 
     def _driver_execute_model(
         self, execute_model_req: Optional[ExecuteModelRequest]
     ) -> Optional[List[SamplerOutput]]:
         """Run execute_model in the driver worker."""
-        assert not self.use_ray_spmd_worker, (
-            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
-        return self.driver_worker.execute_method("execute_model",
-                                                 execute_model_req)
+        raise RuntimeError(
+            "RayDistributedExecutor only supports compiled DAG execution "
+            "and does not expose a separate driver worker loop.")
 
     def _run_workers(
         self,
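The retained loop's flat-rank arithmetic maps (pp_rank, tp_rank) pairs onto self.workers; a worked example of the PP=2, TP=4 case called out in the comments:

# rank = pp_rank * tp_size + tp_rank; with PP=2, TP=4 this yields
# pp_tp_workers == [[0, 1, 2, 3], [4, 5, 6, 7]] (indices into self.workers).
pp_size, tp_size = 2, 4
pp_tp_workers = [[pp * tp_size + tp for tp in range(tp_size)]
                 for pp in range(pp_size)]
assert pp_tp_workers == [[0, 1, 2, 3], [4, 5, 6, 7]]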
@@ -542,33 +481,16 @@ class RayDistributedExecutor(DistributedExecutorBase, Executor):
                 "max_concurrent_workers is not supported yet.")
 
         # Start the ray workers first.
-        ray_workers = self.workers
-        if async_run_tensor_parallel_workers_only:
-            ray_workers = self.non_driver_workers
         ray_worker_outputs = [
             worker.execute_method.remote(sent_method, *args, **kwargs)
-            for worker in ray_workers
+            for worker in self.workers
         ]
-
-        if async_run_tensor_parallel_workers_only:
-            # Just return futures
-            return ray_worker_outputs
-
-        driver_worker_output = []
-        # In SPMD mode, the driver worker is the same as any other worker,
-        # so we only explicitly execute on the driver worker if using a
-        # non-SPMD worker class.
-        if not self.use_ray_spmd_worker:
-            # Start the driver worker after all the ray workers.
-            driver_worker_output = [
-                self.driver_worker.execute_method(sent_method, *args, **kwargs)
-            ]
+        if not self.workers:
+            return []
 
         # Get the results of the ray workers.
-        if self.workers:
-            ray_worker_outputs = ray.get(ray_worker_outputs)
-
-        return driver_worker_output + ray_worker_outputs
+        return ray.get(ray_worker_outputs)
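With the driver/non-driver split gone, _run_workers reduces to a symmetric fan-out: every rank is an identical Ray actor. A self-contained sketch of that reduced shape (execute_method mirrors RayWorkerWrapper's RPC entry point):

import ray
from typing import Any, List

def run_workers_spmd(workers: List[Any], method: str,
                     *args: Any, **kwargs: Any) -> List[Any]:
    """Call `method` on every Ray actor and gather results in rank order."""
    if not workers:
        return []
    refs = [w.execute_method.remote(method, *args, **kwargs)
            for w in workers]
    return ray.get(refs)  # blocks until every rank has returned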
     def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
         """Wait for futures returned from _run_workers()."""
@@ -674,43 +596,14 @@ class RayDistributedExecutor(DistributedExecutorBase, Executor):
     async def _driver_execute_model_async(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[SamplerOutput]:
-        assert not self.use_ray_spmd_worker, (
-            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
-        if not self.tp_driver_workers:
-            return await self.driver_exec_method("execute_model",
-                                                 execute_model_req)
-
-        if self.pp_locks is None:
-            self.pp_locks = [
-                asyncio.Lock()
-                for _ in range(self.parallel_config.pipeline_parallel_size)
-            ]
-
-        tasks = [
-            asyncio.create_task(
-                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
-                                    "execute_model", execute_model_req))
-        ]
-        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
-                                                start=1):
-            tasks.append(
-                asyncio.create_task(
-                    _run_task_with_lock(driver_worker.execute_method.remote,
-                                        self.pp_locks[pp_rank],
-                                        "execute_model", execute_model_req)))
-
-        results = await asyncio.gather(*tasks)
-
-        # Only the last PP stage has the final results.
-        return results[-1]
+        raise RuntimeError(
+            "RayDistributedExecutor only supports compiled DAG execution "
+            "and does not expose a separate driver worker loop.")
 
     async def _start_worker_execution_loop(self):
-        assert not self.use_ray_spmd_worker, (
-            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
-        coros = [
-            worker.execute_method.remote("start_worker_execution_loop")
-            for worker in self.non_driver_workers
-        ]
-        return await asyncio.gather(*coros)
+        raise RuntimeError(
+            "RayDistributedExecutor only supports compiled DAG execution "
+            "and does not expose a separate driver worker loop.")
     def check_health(self) -> None:
         # Assume that the Ray workers are healthy.