[Refactor] Create a function util and cache the results for has_deepgemm, has_deepep, has_pplx (#20187)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-06-28 18:06:38 -04:00 committed by GitHub
parent daec9dea6e
commit 4d36693687
12 changed files with 61 additions and 58 deletions

View File

@@ -6,7 +6,6 @@ fp8 block-quantized case.
 """
 import dataclasses
-import importlib
 from typing import Optional

 import pytest
@@ -21,18 +20,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
+from vllm.utils import has_deep_ep, has_deep_gemm

 from .utils import ProcessGroupInfo, parallel_launch

-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-try:
-    import deep_gemm
-    has_deep_gemm = True
-except ImportError:
-    has_deep_gemm = False
-
-if has_deep_ep:
+if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
@@ -40,19 +32,21 @@ if has_deep_ep:
     from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

-if has_deep_gemm:
+if has_deep_gemm():
+    import deep_gemm
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
         BatchedDeepGemmExperts)
     from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
         DeepGemmExperts)

 requires_deep_ep = pytest.mark.skipif(
-    not has_deep_ep,
+    not has_deep_ep(),
     reason="Requires deep_ep kernels",
 )

 requires_deep_gemm = pytest.mark.skipif(
-    not has_deep_gemm,
+    not has_deep_gemm(),
     reason="Requires deep_gemm kernels",
 )
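For orientation (not part of the diff): a minimal sketch of how these pytest skip markers are typically applied in the test file above; the test name and body are illustrative only.

    @requires_deep_ep
    @requires_deep_gemm
    def test_deepep_deepgemm_example():  # hypothetical test name
        # Skipped automatically when deep_ep or deep_gemm is not installed.
        ...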

View File

@@ -4,7 +4,6 @@ Test deepep dispatch-combine logic
 """
 import dataclasses
-import importlib
 from typing import Optional, Union

 import pytest
@@ -22,12 +21,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
+from vllm.utils import has_deep_ep

 from .utils import ProcessGroupInfo, parallel_launch

-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-
-if has_deep_ep:
+if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
@@ -36,7 +34,7 @@ if has_deep_ep:
     from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

 requires_deep_ep = pytest.mark.skipif(
-    not has_deep_ep,
+    not has_deep_ep(),
     reason="Requires deep_ep kernels",
 )

View File

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import importlib.util
 from typing import TYPE_CHECKING, Any

 import torch
@@ -8,6 +7,7 @@ import torch.distributed as dist

 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
+from vllm.utils import has_deep_ep, has_pplx

 from .base_device_communicator import All2AllManagerBase, Cache
@@ -80,8 +80,8 @@ class PPLXAll2AllManager(All2AllManagerBase):
     """

     def __init__(self, cpu_group):
-        has_pplx = importlib.util.find_spec("pplx_kernels") is not None
-        assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
+        assert has_pplx(
+        ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
         super().__init__(cpu_group)

         if self.internode:
@@ -133,8 +133,8 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     """

     def __init__(self, cpu_group):
-        has_deepep = importlib.util.find_spec("deep_ep") is not None
-        assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
+        assert has_deep_ep(
+        ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
         super().__init__(cpu_group)
         self.handle_cache = Cache()

View File

@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import importlib.util
 from typing import Optional

 import torch
@@ -11,8 +10,6 @@ from vllm.triton_utils import tl, triton

 logger = init_logger(__name__)

-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-

 @triton.jit
 def _silu_mul_fp8_quant_deep_gemm(

View File

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
-import importlib.util
 from typing import Optional

 import torch
@@ -12,14 +11,13 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
     _moe_permute)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
-from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, per_token_group_quant_fp8)
-from vllm.utils import round_up
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
+from vllm.utils import has_deep_gemm, round_up

 logger = init_logger(__name__)

-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-

 @functools.cache
 def deep_gemm_block_shape() -> list[int]:
@@ -41,7 +39,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
     gemm kernel. All of M, N, K and the quantization block_shape must be
     aligned by `dg.get_m_alignment_for_contiguous_layout()`.
     """
-    if not has_deep_gemm:
+    if not has_deep_gemm():
         logger.debug("DeepGemm disabled: deep_gemm not available.")
         return False

View File

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import importlib
 from abc import abstractmethod
 from collections.abc import Iterable
 from dataclasses import dataclass
@@ -32,10 +31,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import direct_register_custom_op
-
-has_pplx = importlib.util.find_spec("pplx_kernels") is not None
-has_deepep = importlib.util.find_spec("deep_ep") is not None
+from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx

 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
@@ -43,9 +39,9 @@ if current_platform.is_cuda_alike():
     from .modular_kernel import (FusedMoEModularKernel,
                                  FusedMoEPermuteExpertsUnpermute,
                                  FusedMoEPrepareAndFinalize)
-    if has_pplx:
+    if has_pplx():
         from .pplx_prepare_finalize import PplxPrepareAndFinalize
-    if has_deepep:
+    if has_deep_ep():
         from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
         from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
                                                  DeepEPLLPrepareAndFinalize)

View File

@@ -104,4 +104,4 @@ def find_free_port():
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
         s.bind(('', 0))
         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        return s.getsockname()[1]
+        return s.getsockname()[1]

View File

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import enum
-import importlib
 from enum import Enum
 from typing import Callable, Optional
@@ -29,13 +28,12 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils import has_pplx

-has_pplx = importlib.util.find_spec("pplx_kernels") is not None

 if current_platform.is_cuda_alike():
     from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
         BatchedPrepareAndFinalize)
-    if has_pplx:
+    if has_pplx():
         from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
             PplxPrepareAndFinalize)
@@ -577,7 +575,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             use_batched_format=True,
         )

-        if has_pplx and isinstance(
+        if has_pplx() and isinstance(
                 prepare_finalize,
                 (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
             # no expert_map support in this case

View File

@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
-import importlib.util
 import logging

 import torch

 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, has_deep_gemm

-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-if has_deep_gemm:
+if has_deep_gemm():
     import deep_gemm

 logger = logging.getLogger(__name__)

View File

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import functools
-import importlib.util
 from typing import Any, Callable, Optional, Union

 import torch
@@ -38,13 +37,12 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils import has_deep_gemm

 ACTIVATION_SCHEMES = ["static", "dynamic"]

 logger = init_logger(__name__)

-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-

 def _is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3
@@ -451,7 +449,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         # Check for DeepGemm support.
         self.allow_deep_gemm = False
         if envs.VLLM_USE_DEEP_GEMM:
-            if not has_deep_gemm:
+            if not has_deep_gemm():
                 logger.warning_once("Failed to import DeepGemm kernels.")
             elif not self.block_quant:
                 logger.warning_once("Model is not block quantized. Not using "

View File

@@ -3,7 +3,6 @@
 # Adapted from https://github.com/sgl-project/sglang/pull/2575
 import functools
-import importlib.util
 import json
 import os
 from typing import Any, Callable, Optional, Union
@@ -19,10 +18,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm

 logger = init_logger(__name__)

-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None

 def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
@@ -109,7 +107,7 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
     """
     return (current_platform.is_cuda()
-            and current_platform.is_device_capability(90) and has_deep_gemm
+            and current_platform.is_device_capability(90) and has_deep_gemm()
             and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
             and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)

View File

@@ -2929,3 +2929,31 @@ def is_torch_equal_or_newer(target: str) -> bool:
 def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
     torch_version = version.parse(torch_version)
     return torch_version >= version.parse(target)
+
+
+@cache
+def _has_module(module_name: str) -> bool:
+    """Return True if *module_name* can be found in the current environment.
+
+    The result is cached so that subsequent queries for the same module incur
+    no additional overhead.
+    """
+    return importlib.util.find_spec(module_name) is not None
+
+
+def has_pplx() -> bool:
+    """Whether the optional `pplx_kernels` package is available."""
+    return _has_module("pplx_kernels")
+
+
+def has_deep_ep() -> bool:
+    """Whether the optional `deep_ep` package is available."""
+    return _has_module("deep_ep")
+
+
+def has_deep_gemm() -> bool:
+    """Whether the optional `deep_gemm` package is available."""
+    return _has_module("deep_gemm")
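Not part of the commit: a self-contained sketch of the caching pattern the new helpers rely on, to show why repeated availability checks stay cheap. The module name used at the bottom is just an example.

    import importlib.util
    from functools import cache


    @cache
    def _has_module(module_name: str) -> bool:
        # find_spec touches the import machinery once per module name;
        # functools.cache memoizes the result for every later call.
        return importlib.util.find_spec(module_name) is not None


    def has_deep_gemm() -> bool:
        return _has_module("deep_gemm")


    if has_deep_gemm():       # first call performs the actual lookup
        import deep_gemm  # noqa: F401
    print(has_deep_gemm())    # later calls return the cached result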