From 4ebc9108a70ef6226a0ed7853b0f6681cd5dc18b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Wed, 8 Oct 2025 22:25:31 +0200 Subject: [PATCH] [Kernel] Centralize platform kernel import in `current_platform.import_kernels` (#26286) Signed-off-by: NickLucche --- vllm/_custom_ops.py | 5 ++--- vllm/platforms/interface.py | 9 +-------- vllm/platforms/tpu.py | 7 +++++-- vllm/platforms/xpu.py | 7 +++++-- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9fa346cca56d1..68ef9b96bbceb 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -12,8 +12,7 @@ from vllm.scalar_type import ScalarType logger = init_logger(__name__) -current_platform.import_core_kernels() -supports_moe_ops = current_platform.try_import_moe_kernels() +current_platform.import_kernels() if TYPE_CHECKING: @@ -1921,7 +1920,7 @@ def moe_wna16_marlin_gemm( ) -if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): +if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): @register_fake("_moe_C::marlin_gemm_moe") def marlin_gemm_moe_fake( diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 6dc49f99ac2ad..e372ebf0cb3f7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -170,22 +170,15 @@ class Platform: return device_id @classmethod - def import_core_kernels(cls) -> None: + def import_kernels(cls) -> None: """Import any platform-specific C kernels.""" try: import vllm._C # noqa: F401 except ImportError as e: logger.warning("Failed to import from vllm._C: %r", e) - - @classmethod - def try_import_moe_kernels(cls) -> bool: - """Import any platform-specific MoE kernels.""" with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 - return True - return False - @classmethod def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": from vllm.attention.backends.registry import _Backend diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 8c23b1de44e4e..1c323ba8200a2 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib from typing import TYPE_CHECKING, Optional, Union, cast import torch @@ -45,8 +46,10 @@ class TpuPlatform(Platform): additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"] @classmethod - def import_core_kernels(cls) -> None: - pass + def import_kernels(cls) -> None: + # Do not import vllm._C + with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 @classmethod def get_attn_backend_cls( diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 2f2f3ab8b9d94..e0c8a6605b7d4 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import os from typing import TYPE_CHECKING, Optional @@ -35,8 +36,10 @@ class XPUPlatform(Platform): device_control_env_var: str = "ZE_AFFINITY_MASK" @classmethod - def import_core_kernels(cls) -> None: - pass + def import_kernels(cls) -> None: + # Do not import vllm._C + with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 @classmethod def get_attn_backend_cls(