[BugFix] Set CUDA_VISIBLE_DEVICES before spawning the subprocesses (#21211)

Signed-off-by: Yinghai Lu <yinghai@thinkingmachines.ai>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Rui Qiao <ruisearch42@gmail.com>
parent dc2f159f8a
commit 11ef7a611e
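
Editor's note: the fix computes, for each local data-parallel rank, the contiguous block of device indices that rank should own and exports it through the platform's device-control variable (CUDA_VISIBLE_DEVICES on CUDA) before the engine subprocess is spawned. The sketch below shows that index arithmetic in isolation; physical_id is a simplified stand-in for vLLM's current_platform.device_id_to_physical_device_id, and all values are hypothetical.

    import os

    def physical_id(i: int) -> int:
        # Simplified assumption: if CUDA_VISIBLE_DEVICES is already set,
        # treat it as the base mapping and index into it; otherwise the
        # logical index is the physical index.
        base = os.getenv("CUDA_VISIBLE_DEVICES")
        return int(base.split(",")[i]) if base else i

    def device_list_for_rank(local_dp_rank: int, world_size: int) -> str:
        # Rank r owns logical devices [r * world_size, (r + 1) * world_size).
        start = local_dp_rank * world_size
        return ",".join(
            str(physical_id(i)) for i in range(start, start + world_size))

    # Example: local DP rank 1 with a per-engine world size of 2 -> "2,3".
    print(device_list_for_rank(local_dp_rank=1, world_size=2))

If the requested range runs past the base device list, physical_id raises IndexError, which is the condition the diff below wraps in a descriptive exception.
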
@@ -910,22 +910,6 @@ class DPEngineCoreProc(EngineCoreProc):
             logger.debug("Setting kv_transfer_config.engine_id to %s",
                          vllm_config.kv_transfer_config.engine_id)
 
-        from vllm.platforms import current_platform
-        device_control_env_var = current_platform.device_control_env_var
-        world_size = vllm_config.parallel_config.world_size
-        # Set CUDA_VISIBLE_DEVICES or equivalent.
-        try:
-            os.environ[device_control_env_var] = ",".join(
-                str(current_platform.device_id_to_physical_device_id(i))
-                for i in range(local_dp_rank *
-                               world_size, (local_dp_rank + 1) * world_size))
-        except IndexError as e:
-            raise Exception(
-                f"Error setting {device_control_env_var}: "
-                f"local range: [{local_dp_rank * world_size}, "
-                f"{(local_dp_rank + 1) * world_size}) "
-                f"base value: \"{os.getenv(device_control_env_var)}\"") from e
-
         self.dp_rank = dp_rank
         self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
 
@@ -1088,14 +1072,41 @@ class DPEngineCoreActor(DPEngineCoreProc):
         vllm_config.parallel_config.data_parallel_rank_local = \
             local_dp_rank
 
-        # Ray sets CUDA_VISIBLE_DEVICES to empty string,
-        # we clean this up to be able to properly initialize
-        # data parallel groups.
-        del os.environ['CUDA_VISIBLE_DEVICES']
+        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
+        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
+        # and this cannot be done in the same way for Ray because:
+        # 1) Ray manages life cycle of all ray workers (including
+        #    DPEngineCoreActor)
+        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
+        # To bypass 2, we need to also set
+        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
+        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
+        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
+        # But vLLM worker assumes visibility into all local GPUs, therefore
+        # this results in incorrect indexing into the GPU ID list.
+        self._set_cuda_visible_devices(vllm_config, local_dp_rank)
 
         super().__init__(vllm_config, local_client, "", executor_class,
                          log_stats)
 
+    def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
+                                  local_dp_rank: int):
+        from vllm.platforms import current_platform
+        device_control_env_var = current_platform.device_control_env_var
+        world_size = vllm_config.parallel_config.world_size
+        # Set CUDA_VISIBLE_DEVICES or equivalent.
+        try:
+            os.environ[device_control_env_var] = ",".join(
+                str(current_platform.device_id_to_physical_device_id(i))
+                for i in range(local_dp_rank *
+                               world_size, (local_dp_rank + 1) * world_size))
+        except IndexError as e:
+            raise Exception(
+                f"Error setting {device_control_env_var}: "
+                f"local range: [{local_dp_rank * world_size}, "
+                f"{(local_dp_rank + 1) * world_size}) "
+                f"base value: \"{os.getenv(device_control_env_var)}\"") from e
+
     def _decorate_logs(self):
         pass
 
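
Editor's note: the NOTE block added above contrasts the Ray path with the multiprocessing path, where the variable can simply be exported in the parent right before the child is spawned. A minimal, self-contained illustration of that inheritance (hypothetical device values, not vLLM code):

    import multiprocessing as mp
    import os

    def child() -> None:
        # The spawned interpreter inherits the parent's environment as it
        # existed at start() time.
        print("child sees:", os.environ.get("CUDA_VISIBLE_DEVICES"))

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"  # set before spawning
        p = ctx.Process(target=child)
        p.start()
        p.join()  # prints: child sees: 2,3

A Ray actor, by contrast, is created by Ray's own worker process, which rewrites CUDA_VISIBLE_DEVICES according to num_gpus, so the actor has to fix up the variable itself in __init__, as _set_cuda_visible_devices now does.
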
@@ -10,12 +10,14 @@ from enum import Enum, auto
 from multiprocessing import Process, connection
 from multiprocessing.process import BaseProcess
 from typing import TYPE_CHECKING, Callable, Optional, Union
+from unittest.mock import patch
 
 import msgspec
 import zmq
 
 from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
 from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
 from vllm.v1.engine.coordinator import DPCoordinator
@@ -105,10 +107,13 @@ class CoreEngineProcManager:
                 "client_handshake_address"] = client_handshake_address
 
         self.processes: list[BaseProcess] = []
+        local_dp_ranks = []
         for index in range(local_engine_count):
             local_index = local_start_index + index
             global_index = start_index + index
 
             # Start EngineCore in background process.
+            local_dp_ranks.append(local_index)
             self.processes.append(
                 context.Process(target=target_fn,
                                 name=f"EngineCore_{global_index}",
@@ -118,9 +123,14 @@ class CoreEngineProcManager:
                                 }))
 
         self._finalizer = weakref.finalize(self, shutdown, self.processes)
+
+        data_parallel = vllm_config.parallel_config.data_parallel_size > 1
         try:
-            for proc in self.processes:
-                proc.start()
+            for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
+                with set_device_control_env_var(
+                        vllm_config, local_dp_rank) if (
+                            data_parallel) else contextlib.nullcontext():
+                    proc.start()
         finally:
             # Kill other procs if not all are running.
             if self.finished_procs():
@@ -145,6 +155,30 @@ class CoreEngineProcManager:
         }
 
 
+@contextlib.contextmanager
+def set_device_control_env_var(vllm_config: VllmConfig,
+                               local_dp_rank: int) -> Iterator[None]:
+    """
+    Temporarily set CUDA_VISIBLE_DEVICES or equivalent
+    for engine subprocess.
+    """
+    world_size = vllm_config.parallel_config.world_size
+    evar = current_platform.device_control_env_var
+    try:
+        value = ",".join(
+            str(current_platform.device_id_to_physical_device_id(i))
+            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
+                           world_size))
+    except IndexError as e:
+        raise Exception(f"Error setting {evar}: "
+                        f"local range: [{local_dp_rank * world_size}, "
+                        f"{(local_dp_rank + 1) * world_size}) "
+                        "base value: "
+                        f"\"{os.getenv(evar)}\"") from e
+    with patch.dict(os.environ, values=((evar, value), )):
+        yield
+
+
 class CoreEngineActorManager:
     """
     Utility class to handle creation, readiness, and shutdown
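
Editor's note: because patch.dict restores os.environ on exit, the device-control variable is only narrowed for the duration of each proc.start() call: the child snapshots the patched environment when it is spawned, while the parent keeps its original device visibility afterwards. A standalone sketch of the same pattern (hypothetical device lists, not calling the vLLM helper):

    import multiprocessing as mp
    import os
    from unittest.mock import patch

    def child() -> None:
        print("child sees:", os.environ.get("CUDA_VISIBLE_DEVICES"))

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        procs = []
        # Hypothetical per-rank device lists; vLLM derives them from
        # local_dp_rank and world_size inside set_device_control_env_var.
        for devices in ("0,1", "2,3"):
            with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": devices}):
                p = ctx.Process(target=child)
                p.start()  # child inherits the patched value
            procs.append(p)  # parent environment restored on exiting the with
        for p in procs:
            p.join()
        print("parent sees:", os.environ.get("CUDA_VISIBLE_DEVICES"))
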
@@ -215,10 +249,9 @@ class CoreEngineActorManager:
 
         self.placement_group_is_local = []
         refs = []
-        for index in range(dp_size):
-            local_index = local_dp_ranks[index]
+        for index, local_index, pg in zip(range(dp_size), local_dp_ranks,
+                                          placement_groups):
             dp_vllm_config = copy.deepcopy(vllm_config)
-            pg = placement_groups[index]
             dp_vllm_config.parallel_config.placement_group = pg
             local_client = index < local_engine_count
             actor = ray.remote(DPEngineCoreActor).options(
@@ -264,7 +297,6 @@ class CoreEngineActorManager:
         local_engine_count = \
             vllm_config.parallel_config.data_parallel_size_local
 
-        nodes = list_nodes()
         nodes = sorted(list_nodes(),
                        key=lambda node: node.node_ip != dp_master_ip)
         assert nodes[0].node_ip == dp_master_ip, (