mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 13:47:55 +08:00
[DP] Copy environment variables to Ray DPEngineCoreActors (#20344)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
parent
a37d75bbec
commit
a6d795d593
@ -2,7 +2,6 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -20,6 +19,7 @@ from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||||
from vllm.sequence import ExecuteModelRequest
|
from vllm.sequence import ExecuteModelRequest
|
||||||
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
|
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
|
||||||
get_ip, get_open_port, make_async)
|
get_ip, get_open_port, make_async)
|
||||||
@ -58,17 +58,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
|||||||
"VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
|
"VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
|
||||||
}
|
}
|
||||||
|
|
||||||
config_home = envs.VLLM_CONFIG_ROOT
|
|
||||||
# This file contains a list of env vars that should not be copied
|
|
||||||
# from the driver to the Ray workers.
|
|
||||||
non_carry_over_env_vars_file = os.path.join(
|
|
||||||
config_home, "ray_non_carry_over_env_vars.json")
|
|
||||||
if os.path.exists(non_carry_over_env_vars_file):
|
|
||||||
with open(non_carry_over_env_vars_file) as f:
|
|
||||||
non_carry_over_env_vars = set(json.load(f))
|
|
||||||
else:
|
|
||||||
non_carry_over_env_vars = set()
|
|
||||||
|
|
||||||
uses_ray: bool = True
|
uses_ray: bool = True
|
||||||
|
|
||||||
def _init_executor(self) -> None:
|
def _init_executor(self) -> None:
|
||||||
@ -335,13 +324,10 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
|||||||
} for (node_id, _) in worker_node_and_gpu_ids]
|
} for (node_id, _) in worker_node_and_gpu_ids]
|
||||||
|
|
||||||
# Environment variables to copy from driver to workers
|
# Environment variables to copy from driver to workers
|
||||||
env_vars_to_copy = [
|
env_vars_to_copy = get_env_vars_to_copy(
|
||||||
v for v in envs.environment_variables
|
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
|
||||||
if v not in self.WORKER_SPECIFIC_ENV_VARS
|
additional_vars=set(current_platform.additional_env_vars),
|
||||||
and v not in self.non_carry_over_env_vars
|
destination="workers")
|
||||||
]
|
|
||||||
|
|
||||||
env_vars_to_copy.extend(current_platform.additional_env_vars)
|
|
||||||
|
|
||||||
# Copy existing env vars to each worker's args
|
# Copy existing env vars to each worker's args
|
||||||
for args in all_args_to_update_environment_variables:
|
for args in all_args_to_update_environment_variables:
|
||||||
@ -350,15 +336,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
|||||||
if name in os.environ:
|
if name in os.environ:
|
||||||
args[name] = os.environ[name]
|
args[name] = os.environ[name]
|
||||||
|
|
||||||
logger.info("non_carry_over_env_vars from config: %s",
|
|
||||||
self.non_carry_over_env_vars)
|
|
||||||
logger.info(
|
|
||||||
"Copying the following environment variables to workers: %s",
|
|
||||||
[v for v in env_vars_to_copy if v in os.environ])
|
|
||||||
logger.info(
|
|
||||||
"If certain env vars should NOT be copied to workers, add them to "
|
|
||||||
"%s file", self.non_carry_over_env_vars_file)
|
|
||||||
|
|
||||||
self._env_vars_for_all_workers = (
|
self._env_vars_for_all_workers = (
|
||||||
all_args_to_update_environment_variables)
|
all_args_to_update_environment_variables)
|
||||||
|
|
||||||
|
|||||||
71
vllm/ray/ray_env.py
Normal file
71
vllm/ray/ray_env.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
CONFIG_HOME = envs.VLLM_CONFIG_ROOT

# This file contains a list of env vars that should not be copied
# from the driver to the Ray workers.
RAY_NON_CARRY_OVER_ENV_VARS_FILE = os.path.join(
    CONFIG_HOME, "ray_non_carry_over_env_vars.json")


def _load_non_carry_over_env_vars(path: str) -> set[str]:
    """Load the user's list of env vars that must not be copied.

    The file is optional and is read while this module is being imported,
    so every failure mode degrades to an empty set (with a warning)
    instead of raising.

    Args:
        path: Location of the JSON file; it is expected to hold a JSON
            list of environment variable names.

    Returns:
        The names listed in the file, or an empty set when the file is
        missing, unreadable, not valid JSON, or not a list of strings.
    """
    try:
        # Open directly rather than checking os.path.exists() first: the
        # file may disappear between the check and the open.
        with open(path, encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        # No opt-out file configured; nothing to exclude.
        return set()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        logger.warning(
            "Failed to parse %s. Using an empty set for non-carry-over env vars.",
            path)
        return set()
    except OSError as e:
        # E.g. permission denied, or the path is a directory.
        logger.warning(
            "Failed to read %s (%s). "
            "Using an empty set for non-carry-over env vars.", path, e)
        return set()

    # set() on a bare JSON string would silently yield single characters,
    # and on a number or nested list would raise TypeError, so validate
    # the shape explicitly.
    if not isinstance(loaded, list) or not all(
            isinstance(name, str) for name in loaded):
        logger.warning(
            "Expected a JSON list of strings in %s. "
            "Using an empty set for non-carry-over env vars.", path)
        return set()

    return set(loaded)


RAY_NON_CARRY_OVER_ENV_VARS = _load_non_carry_over_env_vars(
    RAY_NON_CARRY_OVER_ENV_VARS_FILE)
|
||||||
|
|
||||||
|
|
||||||
|
def get_env_vars_to_copy(exclude_vars: Optional[set[str]] = None,
                         additional_vars: Optional[set[str]] = None,
                         destination: Optional[str] = None) -> set[str]:
    """
    Get the environment variables to copy to downstream Ray actors.

    Example use cases:
    - Copy environment variables from RayDistributedExecutor to Ray workers.
    - Copy environment variables from RayDPClient to Ray DPEngineCoreActor.

    The result contains variable *names* only. A returned name is not
    necessarily set in the current process, so callers should check
    ``os.environ`` before reading its value.

    Args:
        exclude_vars: A set of vllm defined environment variables to exclude
            from copying.
        additional_vars: A set of additional environment variables to copy.
            Names listed in the non-carry-over config file are dropped from
            this set as well, so the user's opt-out is always honoured.
        destination: The destination of the environment variables. Only
            used to make the log message more specific.
    Returns:
        A set of environment variables to copy.
    """
    exclude_vars = exclude_vars or set()
    additional_vars = additional_vars or set()

    env_vars_to_copy = {
        v
        for v in envs.environment_variables
        if v not in exclude_vars and v not in RAY_NON_CARRY_OVER_ENV_VARS
    }
    # The log below tells users that adding a name to the non-carry-over
    # file stops it from being copied, so apply that file to the additional
    # variables too instead of letting them bypass it.
    env_vars_to_copy.update(
        set(additional_vars) - RAY_NON_CARRY_OVER_ENV_VARS)

    to_destination = " to " + destination if destination is not None else ""

    logger.info("RAY_NON_CARRY_OVER_ENV_VARS from config: %s",
                RAY_NON_CARRY_OVER_ENV_VARS)
    # Sort so the logged list is stable across runs; set iteration order
    # for strings varies between processes.
    logger.info("Copying the following environment variables%s: %s",
                to_destination,
                sorted(v for v in env_vars_to_copy if v in os.environ))
    logger.info(
        "If certain env vars should NOT be copied, add them to "
        "%s file", RAY_NON_CARRY_OVER_ENV_VARS_FILE)

    return env_vars_to_copy
|
||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import os
|
||||||
import weakref
|
import weakref
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -15,6 +16,7 @@ import zmq
|
|||||||
|
|
||||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||||
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
|
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
|
||||||
from vllm.v1.engine.coordinator import DPCoordinator
|
from vllm.v1.engine.coordinator import DPCoordinator
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
@ -164,6 +166,7 @@ class CoreEngineActorManager:
|
|||||||
import copy
|
import copy
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
from ray.runtime_env import RuntimeEnv
|
||||||
from ray.util.scheduling_strategies import (
|
from ray.util.scheduling_strategies import (
|
||||||
PlacementGroupSchedulingStrategy)
|
PlacementGroupSchedulingStrategy)
|
||||||
|
|
||||||
@ -175,6 +178,12 @@ class CoreEngineActorManager:
|
|||||||
local_engine_count = \
|
local_engine_count = \
|
||||||
vllm_config.parallel_config.data_parallel_size_local
|
vllm_config.parallel_config.data_parallel_size_local
|
||||||
world_size = vllm_config.parallel_config.world_size
|
world_size = vllm_config.parallel_config.world_size
|
||||||
|
env_vars_set = get_env_vars_to_copy(destination="DPEngineCoreActor")
|
||||||
|
env_vars_dict = {
|
||||||
|
name: os.environ[name]
|
||||||
|
for name in env_vars_set if name in os.environ
|
||||||
|
}
|
||||||
|
runtime_env = RuntimeEnv(env_vars=env_vars_dict)
|
||||||
|
|
||||||
if ray.is_initialized():
|
if ray.is_initialized():
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -210,13 +219,14 @@ class CoreEngineActorManager:
|
|||||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||||
placement_group=pg,
|
placement_group=pg,
|
||||||
placement_group_bundle_index=world_size,
|
placement_group_bundle_index=world_size,
|
||||||
)).remote(vllm_config=dp_vllm_config,
|
),
|
||||||
executor_class=executor_class,
|
runtime_env=runtime_env).remote(vllm_config=dp_vllm_config,
|
||||||
log_stats=log_stats,
|
executor_class=executor_class,
|
||||||
local_client=local_client,
|
log_stats=log_stats,
|
||||||
addresses=addresses,
|
local_client=local_client,
|
||||||
dp_rank=index,
|
addresses=addresses,
|
||||||
local_dp_rank=local_index)
|
dp_rank=index,
|
||||||
|
local_dp_rank=local_index)
|
||||||
if local_client:
|
if local_client:
|
||||||
self.local_engine_actors.append(actor)
|
self.local_engine_actors.append(actor)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user