From 8f0d7eaea87409a54ccaed76995b59c6b0a3d4cf Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Wed, 27 Aug 2025 19:57:38 +0800 Subject: [PATCH] [XPU] Fix OOM issue for data parallel with Ray backend (#22500) Signed-off-by: Fanli Lin Signed-off-by: Fanli Lin Co-authored-by: Cyrus Leung --- vllm/v1/engine/core.py | 27 ++++++++++++++++++--------- vllm/v1/engine/utils.py | 35 +++++++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b614828061846..a7038e2d2c264 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -39,7 +39,8 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses +from vllm.v1.engine.utils import (EngineHandshakeMetadata, EngineZmqAddresses, + get_device_indices) from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats @@ -1169,22 +1170,30 @@ class DPEngineCoreActor(DPEngineCoreProc): # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501 # and get_accelerator_ids_for_accelerator_resource() in worker.py # of ray. - self._set_cuda_visible_devices(vllm_config, local_dp_rank) + self._set_visible_devices(vllm_config, local_dp_rank) super().__init__(vllm_config, local_client, "", executor_class, log_stats) - def _set_cuda_visible_devices(self, vllm_config: VllmConfig, - local_dp_rank: int): + def _set_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int): from vllm.platforms import current_platform - device_control_env_var = current_platform.device_control_env_var + if current_platform.is_xpu(): + pass + else: + device_control_env_var = current_platform.device_control_env_var + self._set_cuda_visible_devices(vllm_config, local_dp_rank, + device_control_env_var) + + def _set_cuda_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int, + device_control_env_var: str): world_size = vllm_config.parallel_config.world_size # Set CUDA_VISIBLE_DEVICES or equivalent. try: - os.environ[device_control_env_var] = ",".join( - str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * - world_size, (local_dp_rank + 1) * world_size)) + value = get_device_indices(device_control_env_var, local_dp_rank, + world_size) + os.environ[device_control_env_var] = value except IndexError as e: raise Exception( f"Error setting {device_control_env_var}: " diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 62f229e286931..56ef8477d267a 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -164,19 +164,33 @@ def set_device_control_env_var(vllm_config: VllmConfig, """ world_size = vllm_config.parallel_config.world_size evar = current_platform.device_control_env_var + + value = get_device_indices(evar, local_dp_rank, world_size) + with patch.dict(os.environ, values=((evar, value), )): + yield + + +def get_device_indices(device_control_env_var: str, local_dp_rank: int, + world_size: int): + """ + Returns a comma-separated string of device indices for the specified + data parallel rank. + + For example, if world_size=2 and local_dp_rank=1, and there are 4 devices, + this will select devices 2 and 3 for local_dp_rank=1. + """ try: value = ",".join( str(current_platform.device_id_to_physical_device_id(i)) for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)) except IndexError as e: - raise Exception(f"Error setting {evar}: " + raise Exception(f"Error setting {device_control_env_var}: " f"local range: [{local_dp_rank * world_size}, " f"{(local_dp_rank + 1) * world_size}) " "base value: " - f"\"{os.getenv(evar)}\"") from e - with patch.dict(os.environ, values=((evar, value), )): - yield + f"\"{os.getenv(device_control_env_var)}\"") from e + return value class CoreEngineActorManager: @@ -254,6 +268,19 @@ class CoreEngineActorManager: dp_vllm_config = copy.deepcopy(vllm_config) dp_vllm_config.parallel_config.placement_group = pg local_client = index < local_engine_count + + # Ray XPU known issue: dpctl initializes the GPU runtime early, so + # setting device env vars in Ray actor's initialization method + # will not affect device selection. See: + # https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501 + if current_platform.is_xpu(): + device_evar = current_platform.device_control_env_var + device_indices = get_device_indices(device_evar, local_index, + world_size) + actor_env_vars = self.env_vars_dict.copy() + actor_env_vars[device_evar] = device_indices + runtime_env = RuntimeEnv(env_vars=actor_env_vars) + actor = ray.remote(DPEngineCoreActor).options( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg,