[Refactor] Create utility functions and cache the results for has_deep_gemm, has_deep_ep, has_pplx (#20187)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-06-28 18:06:38 -04:00 committed by GitHub
parent daec9dea6e
commit 4d36693687
12 changed files with 61 additions and 58 deletions
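The diff below replaces per-file probes such as `importlib.util.find_spec("deep_ep") is not None` (and one try/except import of deep_gemm) with three shared helpers in vllm.utils whose results are memoized with functools.cache; call sites change from module-level boolean flags to function calls (has_deep_ep -> has_deep_ep()). A minimal sketch of the pattern, using only the standard library (the trailing import guard is illustrative, not part of the commit):

import importlib.util
from functools import cache

@cache
def _has_module(module_name: str) -> bool:
    # find_spec walks sys.path and import hooks, so probing is not free;
    # caching makes repeat queries for the same module a dict lookup.
    return importlib.util.find_spec(module_name) is not None

def has_deep_ep() -> bool:
    return _has_module("deep_ep")

if has_deep_ep():
    import deep_ep  # optional dependency; imported only when installed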

View File

@@ -6,7 +6,6 @@ fp8 block-quantized case.
"""
import dataclasses
import importlib
from typing import Optional
import pytest
@@ -21,18 +20,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from .utils import ProcessGroupInfo, parallel_launch
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
try:
    import deep_gemm
    has_deep_gemm = True
except ImportError:
    has_deep_gemm = False
if has_deep_ep:
if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
@@ -40,19 +32,21 @@ if has_deep_ep:
    from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
if has_deep_gemm:
if has_deep_gemm():
    import deep_gemm
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        BatchedDeepGemmExperts)
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
        DeepGemmExperts)
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep,
    not has_deep_ep(),
    reason="Requires deep_ep kernels",
)
requires_deep_gemm = pytest.mark.skipif(
    not has_deep_gemm,
    not has_deep_gemm(),
    reason="Requires deep_gemm kernels",
)

View File

@@ -4,7 +4,6 @@ Test deepep dispatch-combine logic
"""
import dataclasses
import importlib
from typing import Optional, Union
import pytest
@@ -22,12 +21,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep
from .utils import ProcessGroupInfo, parallel_launch
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
@@ -36,7 +34,7 @@ if has_deep_ep:
    from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep,
    not has_deep_ep(),
    reason="Requires deep_ep kernels",
)

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
from typing import TYPE_CHECKING, Any
import torch
@@ -8,6 +7,7 @@ import torch.distributed as dist
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
from .base_device_communicator import All2AllManagerBase, Cache
@@ -80,8 +80,8 @@ class PPLXAll2AllManager(All2AllManagerBase):
    """
    def __init__(self, cpu_group):
        has_pplx = importlib.util.find_spec("pplx_kernels") is not None
        assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
        assert has_pplx(
        ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
        super().__init__(cpu_group)

        if self.internode:
@@ -133,8 +133,8 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
    """
    def __init__(self, cpu_group):
        has_deepep = importlib.util.find_spec("deep_ep") is not None
        assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
        assert has_deep_ep(
        ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
        super().__init__(cpu_group)
        self.handle_cache = Cache()

View File

@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Optional
import torch
@@ -11,8 +10,6 @@ from vllm.triton_utils import tl, triton
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@triton.jit
def _silu_mul_fp8_quant_deep_gemm(

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import importlib.util
from typing import Optional
import torch
@@ -12,14 +11,13 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    _moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache, per_token_group_quant_fp8)
from vllm.utils import round_up
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm, round_up
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@functools.cache
def deep_gemm_block_shape() -> list[int]:
@@ -41,7 +39,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
    gemm kernel. All of M, N, K and the quantization block_shape must be
    aligned by `dg.get_m_alignment_for_contiguous_layout()`.
    """
    if not has_deep_gemm:
    if not has_deep_gemm():
        logger.debug("DeepGemm disabled: deep_gemm not available.")
        return False
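For context on the docstring above: deep_gemm's contiguous-layout kernel requires M, N, K and the quantization block shape to share one alignment value, obtained from dg.get_m_alignment_for_contiguous_layout(). A sketch of such a check (hypothetical helper; the tensor layouts and the 128 default are assumptions, not code from this commit):

import torch

def _shapes_aligned(hidden_states: torch.Tensor, w1: torch.Tensor,
                    block_shape: list[int], align: int = 128) -> bool:
    # hidden_states: (M, K); w1: (num_experts, N, K), as in the fused-MoE
    # layers touched by this diff. All dimensions must share one alignment.
    M, K = hidden_states.shape
    N = w1.shape[1]
    return (M % align == 0 and N % align == 0 and K % align == 0
            and all(b % align == 0 for b in block_shape))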

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from abc import abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
@@ -32,10 +31,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
has_deepep = importlib.util.find_spec("deep_ep") is not None
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
if current_platform.is_cuda_alike():
    from .fused_batched_moe import BatchedTritonExperts
@@ -43,9 +39,9 @@ if current_platform.is_cuda_alike():
    from .modular_kernel import (FusedMoEModularKernel,
                                 FusedMoEPermuteExpertsUnpermute,
                                 FusedMoEPrepareAndFinalize)

    if has_pplx:
    if has_pplx():
        from .pplx_prepare_finalize import PplxPrepareAndFinalize
    if has_deepep:
    if has_deep_ep():
        from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
        from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
                                                 DeepEPLLPrepareAndFinalize)

View File

@@ -104,4 +104,4 @@ def find_free_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(('', 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
\ No newline at end of file
        return s.getsockname()[1]

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import importlib
from enum import Enum
from typing import Callable, Optional
@@ -29,13 +28,12 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
from vllm.utils import has_pplx
if current_platform.is_cuda_alike():
    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
        BatchedPrepareAndFinalize)

    if has_pplx:
    if has_pplx():
        from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
            PplxPrepareAndFinalize)
@@ -577,7 +575,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
            use_batched_format=True,
        )

        if has_pplx and isinstance(
        if has_pplx() and isinstance(
                prepare_finalize,
                (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
            # no expert_map support in this case

View File

@@ -1,15 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
import logging
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op
from vllm.utils import direct_register_custom_op, has_deep_gemm
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
if has_deep_gemm:
if has_deep_gemm():
    import deep_gemm
logger = logging.getLogger(__name__)

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import importlib.util
from typing import Any, Callable, Optional, Union
import torch
@@ -38,13 +37,12 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
@@ -451,7 +449,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        # Check for DeepGemm support.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                logger.warning_once("Model is not block quantized. Not using "

View File

@@ -3,7 +3,6 @@
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
import importlib.util
import json
import os
from typing import Any, Callable, Optional, Union
@@ -19,10 +18,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
@@ -109,7 +107,7 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
    """
    return (current_platform.is_cuda()
            and current_platform.is_device_capability(90) and has_deep_gemm
            and current_platform.is_device_capability(90) and has_deep_gemm()
            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)

View File

@@ -2929,3 +2929,31 @@ def is_torch_equal_or_newer(target: str) -> bool:
def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
    torch_version = version.parse(torch_version)
    return torch_version >= version.parse(target)


@cache
def _has_module(module_name: str) -> bool:
    """Return True if *module_name* can be found in the current environment.

    The result is cached so that subsequent queries for the same module incur
    no additional overhead.
    """
    return importlib.util.find_spec(module_name) is not None


def has_pplx() -> bool:
    """Whether the optional `pplx_kernels` package is available."""
    return _has_module("pplx_kernels")


def has_deep_ep() -> bool:
    """Whether the optional `deep_ep` package is available."""
    return _has_module("deep_ep")


def has_deep_gemm() -> bool:
    """Whether the optional `deep_gemm` package is available."""
    return _has_module("deep_gemm")