[Platform][Kernel] platform-specific kernel loading (#25823)

Signed-off-by: Hank <hcc.mayday@gmail.com>
This commit is contained in:
Hank_ 2025-10-05 19:25:15 +08:00 committed by GitHub
parent 3303cfb4ac
commit 17edd8a807
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 27 additions and 11 deletions

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
import torch import torch
@ -13,16 +12,8 @@ from vllm.scalar_type import ScalarType
logger = init_logger(__name__) logger = init_logger(__name__)
if not current_platform.is_tpu() and not current_platform.is_xpu(): current_platform.import_core_kernels()
try: supports_moe_ops = current_platform.try_import_moe_kernels()
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
supports_moe_ops = False
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import enum import enum
import os import os
import platform import platform
@ -163,6 +164,22 @@ class Platform:
else: else:
return device_id return device_id
@classmethod
def import_core_kernels(cls) -> None:
    """Load the compiled core kernel extension (``vllm._C``).

    A failed import is not fatal: the ``ImportError`` is reported through
    the module logger as a warning and the method returns normally.
    """
    try:
        import vllm._C  # noqa: F401
    except ImportError as import_error:
        logger.warning("Failed to import from vllm._C: %r", import_error)
@classmethod
def try_import_moe_kernels(cls) -> bool:
    """Attempt to load the compiled MoE kernel extension (``vllm._moe_C``).

    Returns:
        True when the import succeeds, False when it raises
        ``ImportError`` (the error is swallowed silently).
    """
    try:
        import vllm._moe_C  # noqa: F401
    except ImportError:
        return False
    return True
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> "_Backend": dtype: torch.dtype) -> "_Backend":

View File

@ -47,6 +47,10 @@ class TpuPlatform(Platform):
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
] ]
@classmethod
def import_core_kernels(cls) -> None:
    """No-op override: TPU does not import the ``vllm._C`` core kernels."""
    return None
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],

View File

@ -34,6 +34,10 @@ class XPUPlatform(Platform):
dist_backend: str = "ccl" # ccl | xccl dist_backend: str = "ccl" # ccl | xccl
device_control_env_var: str = "ZE_AFFINITY_MASK" device_control_env_var: str = "ZE_AFFINITY_MASK"
@classmethod
def import_core_kernels(cls) -> None:
    """No-op override: XPU does not import the ``vllm._C`` core kernels."""
    return None
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],