From 17edd8a807019c8d1e58634aecb1de7984e8d467 Mon Sep 17 00:00:00 2001 From: Hank_ <37239608+ILikeIneine@users.noreply.github.com> Date: Sun, 5 Oct 2025 19:25:15 +0800 Subject: [PATCH] [Platform][Kernel] platform-specific kernel loading (#25823) Signed-off-by: Hank --- vllm/_custom_ops.py | 13 ++----------- vllm/platforms/interface.py | 17 +++++++++++++++++ vllm/platforms/tpu.py | 4 ++++ vllm/platforms/xpu.py | 4 ++++ 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 84d96ee3a84d6..0a83faba513f9 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib from typing import TYPE_CHECKING, Optional, Union import torch @@ -13,16 +12,8 @@ from vllm.scalar_type import ScalarType logger = init_logger(__name__) -if not current_platform.is_tpu() and not current_platform.is_xpu(): - try: - import vllm._C - except ImportError as e: - logger.warning("Failed to import from vllm._C with %r", e) - -supports_moe_ops = False -with contextlib.suppress(ImportError): - import vllm._moe_C # noqa: F401 - supports_moe_ops = True +current_platform.import_core_kernels() +supports_moe_ops = current_platform.try_import_moe_kernels() if TYPE_CHECKING: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index df1395fa842a7..dd51030e4d5c7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import enum import os import platform @@ -163,6 +164,22 @@ class Platform: else: return device_id + @classmethod + def import_core_kernels(cls) -> None: + """ Import any platform-specific C kernels. """ + try: + import vllm._C # noqa: F401 + except ImportError as e: + logger.warning("Failed to import from vllm._C: %r", e) + + @classmethod + def try_import_moe_kernels(cls) -> bool: + """ Import any platform-specific MoE kernels. """ + with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 + return True + return False + @classmethod def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 91a01a4f4ee92..34b7dedbecc73 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -47,6 +47,10 @@ class TpuPlatform(Platform): "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" ] + @classmethod + def import_core_kernels(cls) -> None: + pass + @classmethod def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 3ccbae58726f5..3efd498cf58e2 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -34,6 +34,10 @@ class XPUPlatform(Platform): dist_backend: str = "ccl" # ccl | xccl device_control_env_var: str = "ZE_AFFINITY_MASK" + @classmethod + def import_core_kernels(cls) -> None: + pass + @classmethod def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str],