From 6e2c19ce227ecf285ed24a138b91570b3a2d57a6 Mon Sep 17 00:00:00 2001
From: Yang Yang
Date: Mon, 7 Jul 2025 12:32:32 +0800
Subject: [PATCH] [Refactor] Abstract Platform Interface for Distributed
 Backend and Add xccl Support for Intel XPU (#19410)

Signed-off-by: dbyoung18
Signed-off-by: Kunshang Ji
Co-authored-by: Kunshang Ji
---
 docs/getting_started/installation/gpu/xpu.inc.md |  5 +++++
 vllm/platforms/__init__.py                       | 13 +++++++++++--
 vllm/platforms/cpu.py                            |  1 +
 vllm/platforms/cuda.py                           |  1 +
 vllm/platforms/hpu.py                            |  1 +
 vllm/platforms/interface.py                      |  3 +++
 vllm/platforms/neuron.py                         |  1 +
 vllm/platforms/rocm.py                           |  1 +
 vllm/platforms/tpu.py                            |  1 +
 vllm/platforms/xpu.py                            |  1 +
 vllm/utils/__init__.py                           |  6 ++++++
 vllm/v1/worker/cpu_worker.py                     |  4 +++-
 vllm/v1/worker/gpu_worker.py                     |  3 ++-
 vllm/v1/worker/tpu_worker.py                     |  3 ++-
 vllm/worker/hpu_worker.py                        |  3 ++-
 vllm/worker/neuron_worker.py                     |  2 +-
 vllm/worker/worker.py                            |  3 ++-
 17 files changed, 44 insertions(+), 8 deletions(-)

diff --git a/docs/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu/xpu.inc.md
index 4469be36c0075..1514a0c2d3cd4 100644
--- a/docs/getting_started/installation/gpu/xpu.inc.md
+++ b/docs/getting_started/installation/gpu/xpu.inc.md
@@ -81,4 +81,9 @@ python -m vllm.entrypoints.openai.api_server \
 By default, a ray instance will be launched automatically if no existing one is detected in the system, with `num-gpus` equals to `parallel_config.world_size`. We recommend properly starting a ray cluster before execution, referring to the helper script.
 
 # --8<-- [end:supported-features]
+# --8<-- [start:distributed-backend]
+
+The XPU platform uses **torch-ccl** as its distributed backend for torch<2.8 and **xccl** for torch>=2.8, since torch 2.8 ships **xccl** as a built-in backend for XPU.
+
+# --8<-- [end:distributed-backend]
 # --8<-- [end:extra-information]
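The version gate documented above can be probed directly on a given install. A minimal sketch, assuming only public `torch` APIs (`probe_xpu_backend` is a hypothetical helper, not part of this patch; it mirrors, rather than reuses, the `supports_xccl()` utility added later in the patch):

```python
# Illustrative sketch: report which distributed backend an XPU build
# of torch would use under the rule documented above.
import torch
from packaging.version import Version

def probe_xpu_backend() -> str:
    # torch 2.8 ships xccl as a built-in backend for XPU; older
    # releases rely on oneCCL bindings (torch-ccl) exposing "ccl".
    base = Version(torch.__version__.split("+")[0])
    if base >= Version("2.8.0") and torch.distributed.is_xccl_available():
        return "xccl"
    return "ccl"

print(probe_xpu_backend())
```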
diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index 13453d2c4b4b2..7b8953fd75bb0 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -7,7 +7,7 @@ from itertools import chain
 from typing import TYPE_CHECKING, Optional
 
 from vllm.plugins import load_plugins_by_group
-from vllm.utils import resolve_obj_by_qualname
+from vllm.utils import resolve_obj_by_qualname, supports_xccl
 
 from .interface import _Backend  # noqa: F401
 from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -139,10 +139,19 @@ def xpu_platform_plugin() -> Optional[str]:
     try:
         # installed IPEX if the machine has XPUs.
         import intel_extension_for_pytorch  # noqa: F401
-        import oneccl_bindings_for_pytorch  # noqa: F401
         import torch
+        if supports_xccl():
+            dist_backend = "xccl"
+        else:
+            dist_backend = "ccl"
+            import oneccl_bindings_for_pytorch  # noqa: F401
+
         if hasattr(torch, 'xpu') and torch.xpu.is_available():
             is_xpu = True
+            from vllm.platforms.xpu import XPUPlatform
+            XPUPlatform.dist_backend = dist_backend
+            logger.debug("Confirmed %s backend is available.",
+                         XPUPlatform.dist_backend)
             logger.debug("Confirmed XPU platform is available.")
     except Exception as e:
         logger.debug("XPU platform is not available because: %s", str(e))
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 1050d3c593443..676a440a79db8 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -37,6 +37,7 @@ class CpuPlatform(Platform):
     device_name: str = "cpu"
     device_type: str = "cpu"
     dispatch_key: str = "CPU"
+    dist_backend: str = "gloo"
 
     @property
     def supported_dtypes(self) -> list[torch.dtype]:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 0a5f4004e4488..50eedfa3c412f 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -56,6 +56,7 @@ class CudaPlatformBase(Platform):
     device_type: str = "cuda"
     dispatch_key: str = "CUDA"
     ray_device_key: str = "GPU"
+    dist_backend: str = "nccl"
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
 
     @property
diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index 3cf28950190c8..0b1e2f2327901 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -26,6 +26,7 @@ class HpuPlatform(Platform):
     device_type: str = "hpu"
     dispatch_key: str = "HPU"
     ray_device_key: str = "HPU"
+    dist_backend: str = "hccl"
     device_control_env_var: str = "HABANA_VISIBLE_MODULES"
 
     @classmethod
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 567d5cbf503fe..b0ef9905481b4 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -129,6 +129,9 @@ class Platform:
     # compilation strategy.
     simple_compile_backend: str = "inductor"
 
+    # The backend used for distributed communication.
+    dist_backend: str = ""
+
     supported_quantization: list[str] = []
 
     additional_env_vars: list[str] = []
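Every platform diff above makes the same move: the distributed-backend string becomes a class attribute on `Platform`, overridden per subclass, with `""` on the base class as the unset default. A minimal, self-contained sketch of that pattern (class and function names here are simplified stand-ins, not the actual vLLM symbols):

```python
# Hedged sketch of the class-attribute pattern this patch introduces.
import torch.distributed as dist

class Platform:
    # Overridden by each platform; "" means no platform default.
    dist_backend: str = ""

class CudaLikePlatform(Platform):
    dist_backend = "nccl"

class XpuLikePlatform(Platform):
    dist_backend = "ccl"  # flipped to "xccl" at plugin-detection time

def init_worker(platform: Platform, init_method: str,
                world_size: int, rank: int) -> None:
    # Call sites read the attribute instead of hard-coding a literal.
    dist.init_process_group(backend=platform.dist_backend,
                            init_method=init_method,
                            world_size=world_size,
                            rank=rank)
```

The payoff is that supporting a new accelerator means setting one attribute on its platform class rather than touching every worker's call site.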
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index 04e918d7aebee..cb8ac8db669fe 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -30,6 +30,7 @@ class NeuronPlatform(Platform):
     device_type: str = "neuron"
     ray_device_key: str = "neuron_cores"
     supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
+    dist_backend: str = "gloo"
     device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
 
     @classmethod
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 4550ef570684b..31f4699cd1b0c 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -164,6 +164,7 @@ class RocmPlatform(Platform):
     device_type: str = "cuda"
     dispatch_key: str = "CUDA"
     ray_device_key: str = "GPU"
+    dist_backend: str = "nccl"
     # rocm shares the same device control env var as CUDA
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
 
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index a8c8cb46de2cc..6810944c848d7 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -31,6 +31,7 @@ class TpuPlatform(Platform):
     device_type: str = "tpu"
     dispatch_key: str = "XLA"
     ray_device_key: str = "TPU"
+    dist_backend: str = "gloo"
     device_control_env_var: str = "TPU_VISIBLE_CHIPS"
 
     simple_compile_backend: str = "openxla"
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 5bd34033233ab..de715fd894c33 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -29,6 +29,7 @@ class XPUPlatform(Platform):
     # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20  # noqa: E501
     ray_device_key: str = "GPU"
+    dist_backend: str = "ccl"  # ccl | xccl
     device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR"
 
     @classmethod
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 9550b056fbba9..9322e3cc477a0 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -1886,6 +1886,12 @@ def supports_dynamo() -> bool:
     return base_torch_version >= Version("2.4.0")
 
 
+# Supports xccl with PyTorch versions >= 2.8.0 for XPU platform
+def supports_xccl() -> bool:
+    return is_torch_equal_or_newer(
+        "2.8.0") and torch.distributed.is_xccl_available()
+
+
 # Some backends use pytorch version < 2.4.0 which doesn't
 # support `torch.library.custom_op`.
 def supports_custom_op() -> bool:
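The worker diffs that follow all make the same mechanical change: a hard-coded backend literal (`"gloo"`, `'hccl'`, or an omitted default) becomes `current_platform.dist_backend`. A hedged, runnable sketch of the before/after call shape (every name below is a simplified stand-in for the vLLM one; the real `init_worker_distributed_environment` takes vLLM config objects):

```python
# Self-contained sketch of the call-site change repeated in the workers.
class _FakePlatform:
    dist_backend = "gloo"  # stand-in for the detected platform's value

current_platform = _FakePlatform()

def init_worker_distributed_environment(rank: int, init_method: str,
                                        local_rank: int,
                                        backend: str) -> None:
    # In vLLM this ends in torch.distributed.init_process_group(...).
    print(f"rank={rank} local_rank={local_rank} backend={backend}")

# Before: each worker passed a literal such as "gloo" or "hccl".
# After: every worker defers to the platform object.
init_worker_distributed_environment(0, "tcp://127.0.0.1:29500", 0,
                                    current_platform.dist_backend)
```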
diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py
index de575d604055f..7712b7974544f 100644
--- a/vllm/v1/worker/cpu_worker.py
+++ b/vllm/v1/worker/cpu_worker.py
@@ -11,6 +11,7 @@ from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.model_executor.utils import set_random_seed
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
@@ -58,7 +59,8 @@ class CPUWorker(Worker):
         # Initialize the distributed environment.
         init_worker_distributed_environment(self.vllm_config, self.rank,
                                             self.distributed_init_method,
-                                            self.local_rank, "gloo")
+                                            self.local_rank,
+                                            current_platform.dist_backend)
 
         # Set random seed.
         set_random_seed(self.model_config.seed)
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 9e7e44d068612..d1df0fd959b5e 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -157,7 +157,8 @@ class Worker(WorkerBase):
         # Initialize the distributed environment.
         init_worker_distributed_environment(self.vllm_config, self.rank,
                                             self.distributed_init_method,
-                                            self.local_rank)
+                                            self.local_rank,
+                                            current_platform.dist_backend)
 
         # Set random seed.
         set_random_seed(self.model_config.seed)
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index a64ce881fe318..ade4d08211683 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -18,6 +18,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
+from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -300,7 +301,7 @@ class TPUWorker:
             rank=rank,
             local_rank=local_rank,
             distributed_init_method=distributed_init_method,
-            backend="gloo",
+            backend=current_platform.dist_backend,
         )
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index 6d76ea499a901..560110df0a322 100644
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -23,6 +23,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import bind_kv_cache
@@ -413,7 +414,7 @@ def init_worker_distributed_environment(
                                  rank,
                                  distributed_init_method,
                                  local_rank,
-                                 backend='hccl')
+                                 backend=current_platform.dist_backend)
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py
index 662bde6bc07b0..4e1408300fb8b 100644
--- a/vllm/worker/neuron_worker.py
+++ b/vllm/worker/neuron_worker.py
@@ -156,7 +156,7 @@ class NeuronWorker(LocalOrDistributedWorkerBase):
             rank=self.rank,
             local_rank=self.local_rank,
             distributed_init_method=self.distributed_init_method,
-            backend="gloo",
+            backend=current_platform.dist_backend,
         )
 
         ensure_model_parallel_initialized(
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 9a928632688a1..21e684a3fb5a0 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -530,7 +530,8 @@ def init_worker_distributed_environment(
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
 
     init_distributed_environment(parallel_config.world_size, rank,
-                                 distributed_init_method, local_rank)
+                                 distributed_init_method, local_rank,
+                                 current_platform.dist_backend)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
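Whichever worker path initializes the process group, the backend that was actually selected can be double-checked at runtime; a hedged verification snippet, independent of this patch:

```python
# Confirm which backend torch.distributed ended up with.
import torch.distributed as dist

if dist.is_initialized():
    # Expect "nccl", "gloo", "hccl", "ccl", or "xccl" per platform.
    print(dist.get_backend())
```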