[Refactor] Abstract Platform Interface for Distributed Backend and Add xccl Support for Intel XPU (#19410)

Signed-off-by: dbyoung18 <yang5.yang@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Yang Yang authored 2025-07-07 12:32:32 +08:00, committed by GitHub
parent 47db8c2c15
commit 6e2c19ce22
17 changed files with 44 additions and 8 deletions

View File

@@ -81,4 +81,9 @@ python -m vllm.entrypoints.openai.api_server \
 By default, a Ray instance will be launched automatically if no existing one is detected in the system, with `num-gpus` equal to `parallel_config.world_size`. We recommend properly starting a Ray cluster before execution, referring to the <gh-file:examples/online_serving/run_cluster.sh> helper script.
 
 # --8<-- [end:supported-features]
+
+# --8<-- [start:distributed-backend]
+The XPU platform uses **torch-ccl** as its distributed backend for torch < 2.8 and **xccl** for torch >= 2.8, since torch 2.8 adds **xccl** as a built-in distributed backend for XPU.
+# --8<-- [end:distributed-backend]
+
 # --8<-- [end:extra-information]
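Illustratively, the documented cut-over reduces to a plain version comparison. A minimal standalone sketch, assuming only that torch is installed; `pick_xpu_dist_backend` is a hypothetical name for this example, not vLLM API:

```python
import torch
from packaging.version import Version


def pick_xpu_dist_backend() -> str:
    # Drop any local build suffix such as "+xpu" before comparing.
    torch_version = Version(torch.__version__.split("+")[0])
    # torch >= 2.8 ships xccl built in; older torch needs torch-ccl ("ccl").
    return "xccl" if torch_version >= Version("2.8.0") else "ccl"


print(pick_xpu_dist_backend())
```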

View File

@@ -7,7 +7,7 @@ from itertools import chain
 from typing import TYPE_CHECKING, Optional
 
 from vllm.plugins import load_plugins_by_group
-from vllm.utils import resolve_obj_by_qualname
+from vllm.utils import resolve_obj_by_qualname, supports_xccl
 
 from .interface import _Backend  # noqa: F401
 from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -139,10 +139,19 @@ def xpu_platform_plugin() -> Optional[str]:
     try:
         # installed IPEX if the machine has XPUs.
         import intel_extension_for_pytorch  # noqa: F401
-        import oneccl_bindings_for_pytorch  # noqa: F401
         import torch
+
+        if supports_xccl():
+            dist_backend = "xccl"
+        else:
+            dist_backend = "ccl"
+            import oneccl_bindings_for_pytorch  # noqa: F401
+
         if hasattr(torch, 'xpu') and torch.xpu.is_available():
             is_xpu = True
             from vllm.platforms.xpu import XPUPlatform
+            XPUPlatform.dist_backend = dist_backend
+            logger.debug("Confirmed %s backend is available.",
+                         XPUPlatform.dist_backend)
             logger.debug("Confirmed XPU platform is available.")
     except Exception as e:
         logger.debug("XPU platform is not available because: %s", str(e))

View File

@@ -37,6 +37,7 @@ class CpuPlatform(Platform):
     device_name: str = "cpu"
     device_type: str = "cpu"
     dispatch_key: str = "CPU"
+    dist_backend: str = "gloo"
 
     @property
     def supported_dtypes(self) -> list[torch.dtype]:

View File

@@ -56,6 +56,7 @@ class CudaPlatformBase(Platform):
     device_type: str = "cuda"
     dispatch_key: str = "CUDA"
     ray_device_key: str = "GPU"
+    dist_backend: str = "nccl"
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
 
     @property

View File

@@ -26,6 +26,7 @@ class HpuPlatform(Platform):
     device_type: str = "hpu"
     dispatch_key: str = "HPU"
     ray_device_key: str = "HPU"
+    dist_backend: str = "hccl"
     device_control_env_var: str = "HABANA_VISIBLE_MODULES"
 
     @classmethod

View File

@@ -129,6 +129,9 @@ class Platform:
     # compilation strategy.
     simple_compile_backend: str = "inductor"
 
+    # The backend used for distributed communication.
+    dist_backend: str = ""
+
     supported_quantization: list[str] = []
     additional_env_vars: list[str] = []
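Each concrete platform below overrides this attribute, and an out-of-tree platform plugin would do the same. A minimal sketch; the `FpgaPlatform` class is made up for illustration, and the `_enum`/`PlatformEnum.OOT` usage is an assumption about vLLM's out-of-tree plugin convention:

```python
from vllm.platforms.interface import Platform, PlatformEnum


class FpgaPlatform(Platform):
    """Hypothetical out-of-tree platform, for illustration only."""
    _enum = PlatformEnum.OOT
    device_name: str = "fpga"
    device_type: str = "fpga"
    # Workers pass this value straight through to
    # torch.distributed.init_process_group(backend=...).
    dist_backend: str = "gloo"
```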

View File

@@ -30,6 +30,7 @@ class NeuronPlatform(Platform):
     device_type: str = "neuron"
     ray_device_key: str = "neuron_cores"
     supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
+    dist_backend: str = "gloo"
     device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
 
     @classmethod

View File

@@ -164,6 +164,7 @@ class RocmPlatform(Platform):
     device_type: str = "cuda"
     dispatch_key: str = "CUDA"
     ray_device_key: str = "GPU"
+    dist_backend: str = "nccl"
     # rocm shares the same device control env var as CUDA
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

View File

@@ -31,6 +31,7 @@ class TpuPlatform(Platform):
     device_type: str = "tpu"
     dispatch_key: str = "XLA"
     ray_device_key: str = "TPU"
+    dist_backend: str = "gloo"
     device_control_env_var: str = "TPU_VISIBLE_CHIPS"
 
     simple_compile_backend: str = "openxla"

View File

@@ -29,6 +29,7 @@ class XPUPlatform(Platform):
     # Intel XPU's device key is "GPU" for Ray.
     # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
     ray_device_key: str = "GPU"
+    dist_backend: str = "ccl"  # ccl | xccl
     device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR"
 
     @classmethod

View File

@@ -1886,6 +1886,12 @@ def supports_dynamo() -> bool:
     return base_torch_version >= Version("2.4.0")
 
 
+# Supports xccl with PyTorch versions >= 2.8.0 for XPU platform
+def supports_xccl() -> bool:
+    return is_torch_equal_or_newer(
+        "2.8.0") and torch.distributed.is_xccl_available()
+
+
 # Some backends use pytorch version < 2.4.0 which doesn't
 # support `torch.library.custom_op`.
 def supports_custom_op() -> bool:
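The order of the two conditions matters: `and` short-circuits, so `torch.distributed.is_xccl_available()`, which may not exist on older torch builds, is never evaluated unless the version check passes. A hedged usage sketch mirroring the plugin's choice:

```python
# Sketch: pick the backend string the way the XPU plugin does.
from vllm.utils import supports_xccl

backend = "xccl" if supports_xccl() else "ccl"
print(f"XPU distributed backend: {backend}")
```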

View File

@@ -11,6 +11,7 @@ from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.model_executor.utils import set_random_seed
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
@@ -58,7 +59,8 @@ class CPUWorker(Worker):
         # Initialize the distributed environment.
         init_worker_distributed_environment(self.vllm_config, self.rank,
                                             self.distributed_init_method,
-                                            self.local_rank, "gloo")
+                                            self.local_rank,
+                                            current_platform.dist_backend)
 
         # Set random seed.
         set_random_seed(self.model_config.seed)

View File

@@ -157,7 +157,8 @@ class Worker(WorkerBase):
         # Initialize the distributed environment.
         init_worker_distributed_environment(self.vllm_config, self.rank,
                                             self.distributed_init_method,
-                                            self.local_rank)
+                                            self.local_rank,
+                                            current_platform.dist_backend)
 
         # Set random seed.
         set_random_seed(self.model_config.seed)

View File

@@ -18,6 +18,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
+from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -300,7 +301,7 @@ class TPUWorker:
             rank=rank,
             local_rank=local_rank,
             distributed_init_method=distributed_init_method,
-            backend="gloo",
+            backend=current_platform.dist_backend,
         )
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,

View File

@@ -23,6 +23,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import bind_kv_cache
@@ -413,7 +414,7 @@ def init_worker_distributed_environment(
                                  rank,
                                  distributed_init_method,
                                  local_rank,
-                                 backend='hccl')
+                                 backend=current_platform.dist_backend)
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)

View File

@@ -156,7 +156,7 @@ class NeuronWorker(LocalOrDistributedWorkerBase):
             rank=self.rank,
             local_rank=self.local_rank,
             distributed_init_method=self.distributed_init_method,
-            backend="gloo",
+            backend=current_platform.dist_backend,
         )
         ensure_model_parallel_initialized(

View File

@@ -530,7 +530,8 @@ def init_worker_distributed_environment(
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
 
     init_distributed_environment(parallel_config.world_size, rank,
-                                 distributed_init_method, local_rank)
+                                 distributed_init_method, local_rank,
+                                 current_platform.dist_backend)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
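All of the worker changes above are the same mechanical substitution: a hardcoded backend literal becomes the platform attribute. A minimal sketch of the resulting call with illustrative single-process values (real workers take these from the parallel config and the executor; the keyword names follow the positional order shown in the diff):

```python
from vllm.distributed import init_distributed_environment
from vllm.platforms import current_platform

# Previously this was "gloo", 'hccl', or an implicit "nccl" depending on
# the worker; now every worker defers to its platform's dist_backend.
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:29500",
    local_rank=0,
    backend=current_platform.dist_backend,
)
```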