[core][misc] remove use_dummy_driver for _run_workers (#10920)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 1b62745b1d
commit 7be15d9356
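
In short: each Ray executor now collects node and GPU ids by querying the driver dummy worker and the Ray workers directly, and _run_workers loses its use_dummy_driver parameter and branch. A minimal standalone sketch of the new collection pattern (assuming only that Ray is installed; DummyWorker and its get_node_and_gpu_ids method are illustrative stand-ins, not the vLLM worker API):

# Standalone sketch of the new node/GPU-id collection loop. Assumes only that
# Ray is installed; DummyWorker is an illustrative stand-in, not vLLM's worker.
import ray


@ray.remote
class DummyWorker:

    def get_node_and_gpu_ids(self):
        # Stand-in for the real worker method: report this node's Ray node id
        # and the GPU ids Ray assigned to this actor.
        return ray.get_runtime_context().get_node_id(), ray.get_gpu_ids()


ray.init()
driver_dummy_worker = None  # can be None when the Ray SPMD worker path is used
workers = [DummyWorker.remote() for _ in range(2)]

# Mirrors the added lines in the diffs below: query each handle directly
# instead of _run_workers("get_node_and_gpu_ids", use_dummy_driver=True).
worker_node_and_gpu_ids = []
for worker in [driver_dummy_worker] + workers:
    if worker is None:
        continue
    worker_node_and_gpu_ids.append(ray.get(worker.get_node_and_gpu_ids.remote()))

print(worker_node_and_gpu_ids)
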
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -188,8 +188,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
         self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
 
         # Get the set of GPU IDs used on each node.
-        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
-                                                     use_dummy_driver=True)
+        worker_node_and_gpu_ids = []
+        for worker in [self.driver_dummy_worker] + self.workers:
+            if worker is None:
+                # driver_dummy_worker can be None when using ray spmd worker.
+                continue
+            worker_node_and_gpu_ids.append(
+                ray.get(worker.get_node_and_gpu_ids.remote()) \
+            )  # type: ignore
 
         node_workers = defaultdict(list)  # node id -> list of worker ranks
         node_gpus = defaultdict(list)  # node id -> list of gpu ids
@@ -329,7 +335,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
         async_run_tensor_parallel_workers_only: bool = False,
         all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
-        use_dummy_driver: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
@@ -389,18 +394,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
         driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
 
         # Start the driver worker after all the ray workers.
-        if not use_dummy_driver:
-            driver_worker_output = [
-                self.driver_worker.execute_method(method, *driver_args,
-                                                  **driver_kwargs)
-            ]
-        else:
-            assert self.driver_dummy_worker is not None
-            driver_worker_output = [
-                ray.get(
-                    self.driver_dummy_worker.execute_method.remote(
-                        method, *driver_args, **driver_kwargs))
-            ]
+        driver_worker_output = [
+            self.driver_worker.execute_method(method, *driver_args,
+                                              **driver_kwargs)
+        ]
 
         # Get the results of the ray workers.
         if self.workers:
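
The driver-path simplification above repeats in the HPU and TPU executors below. A hypothetical, self-contained sketch of the resulting dispatch shape (LocalWorker, RemoteWorker, and run_workers are illustrative names, not vLLM classes):

# Hypothetical, self-contained sketch of the simplified dispatch: the driver
# worker always runs the method in-process, Ray workers run it remotely, and
# there is no use_dummy_driver branch anymore. Not vLLM's real classes.
from typing import Any, List

import ray


class LocalWorker:

    def execute_method(self, method: str, *args, **kwargs) -> Any:
        return f"driver ran {method}"


@ray.remote
class RemoteWorker:

    def execute_method(self, method: str, *args, **kwargs) -> Any:
        return f"ray worker ran {method}"


def run_workers(driver_worker: LocalWorker, workers: List[Any], method: str,
                *args, **kwargs) -> List[Any]:
    # Driver output now always comes from the in-process worker.
    driver_output = [driver_worker.execute_method(method, *args, **kwargs)]
    # Ray workers are dispatched asynchronously and gathered with ray.get().
    futures = [w.execute_method.remote(method, *args, **kwargs) for w in workers]
    return driver_output + ray.get(futures)


if __name__ == "__main__":
    ray.init()
    print(run_workers(LocalWorker(), [RemoteWorker.remote()], "noop"))
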
--- a/vllm/executor/ray_hpu_executor.py
+++ b/vllm/executor/ray_hpu_executor.py
@@ -163,9 +163,14 @@ class RayHPUExecutor(DistributedGPUExecutor):
         # node will be placed first.
         self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
 
-        # Get the set of GPU IDs used on each node.
-        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
-                                                     use_dummy_driver=True)
+        worker_node_and_gpu_ids = []
+        for worker in [self.driver_dummy_worker] + self.workers:
+            if worker is None:
+                # driver_dummy_worker can be None when using ray spmd worker.
+                continue
+            worker_node_and_gpu_ids.append(
+                ray.get(worker.get_node_and_gpu_ids.remote()) \
+            )  # type: ignore
 
         node_workers = defaultdict(list)  # node id -> list of worker ranks
         node_gpus = defaultdict(list)  # node id -> list of gpu ids
@@ -296,7 +301,6 @@ class RayHPUExecutor(DistributedGPUExecutor):
         async_run_tensor_parallel_workers_only: bool = False,
         all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
-        use_dummy_driver: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
@@ -356,18 +360,10 @@ class RayHPUExecutor(DistributedGPUExecutor):
         driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
 
         # Start the driver worker after all the ray workers.
-        if not use_dummy_driver:
-            driver_worker_output = [
-                self.driver_worker.execute_method(method, *driver_args,
-                                                  **driver_kwargs)
-            ]
-        else:
-            assert self.driver_dummy_worker is not None
-            driver_worker_output = [
-                ray.get(
-                    self.driver_dummy_worker.execute_method.remote(
-                        method, *driver_args, **driver_kwargs))
-            ]
+        driver_worker_output = [
+            self.driver_worker.execute_method(method, *driver_args,
+                                              **driver_kwargs)
+        ]
 
         # Get the results of the ray workers.
         if self.workers:
--- a/vllm/executor/ray_tpu_executor.py
+++ b/vllm/executor/ray_tpu_executor.py
@@ -137,8 +137,14 @@ class RayTPUExecutor(TPUExecutor):
         self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
 
         # Get the set of TPU IDs used on each node.
-        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
-                                                     use_dummy_driver=True)
+        worker_node_and_gpu_ids = []
+        for worker in [self.driver_dummy_worker] + self.workers:
+            if worker is None:
+                # driver_dummy_worker can be None when using ray spmd worker.
+                continue
+            worker_node_and_gpu_ids.append(
+                ray.get(worker.get_node_and_gpu_ids.remote()) \
+            )  # type: ignore
 
         node_workers = defaultdict(list)
         for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
@@ -199,7 +205,6 @@ class RayTPUExecutor(TPUExecutor):
         async_run_remote_workers_only: bool = False,
         all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
-        use_dummy_driver: bool = False,
         max_concurrent_workers: Optional[int] = None,
         use_ray_compiled_dag: bool = False,
         **kwargs,
@@ -241,14 +246,8 @@ class RayTPUExecutor(TPUExecutor):
         driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
 
         # Start the driver worker after all the ray workers.
-        if not use_dummy_driver:
-            driver_worker_output = self.driver_worker.execute_method(
-                method, *driver_args, **driver_kwargs)
-        else:
-            assert self.driver_dummy_worker is not None
-            driver_worker_output = ray.get(
-                self.driver_dummy_worker.execute_method.remote(
-                    method, *driver_args, **driver_kwargs))
+        driver_worker_output = self.driver_worker.execute_method(
+            method, *driver_args, **driver_kwargs)
         # Get the results of the ray workers.
         if self.workers:
             ray_worker_outputs = ray.get(ray_worker_outputs)
--- a/vllm/executor/ray_xpu_executor.py
+++ b/vllm/executor/ray_xpu_executor.py
@@ -1,6 +1,8 @@
 import asyncio
 from typing import List, Optional
 
+import ray
+
 import vllm.envs as envs
 from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync
 from vllm.executor.xpu_executor import XPUExecutor
@@ -14,8 +16,13 @@ class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
 
     def _get_env_vars_to_be_updated(self):
         # Get the set of GPU IDs used on each node.
-        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
-                                                     use_dummy_driver=True)
+        worker_node_and_gpu_ids = []
+        for worker in [self.driver_dummy_worker] + self.workers:
+            if worker is None:
+                # driver_dummy_worker can be None when using ray spmd worker.
+                continue
+            worker_node_and_gpu_ids.append(
+                ray.get(worker.get_node_and_gpu_ids.remote()))  # type: ignore
 
         # Set environment variables for the driver and workers.
         all_args_to_update_environment_variables = [({