[Refactor] Create a function util and cache the results for has_deepgemm, has_deepep, has_pplx (#20187)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent daec9dea6e
commit 4d36693687
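At a glance, the change replaces ad-hoc import-time probes with cached helper functions exposed from vllm.utils. Below is a condensed, hedged illustration of the new call-site pattern; it assumes the post-change vllm.utils API shown in the hunks that follow, and the surrounding script (the print statement in particular) is illustrative only:

# Illustrative sketch of the call-site pattern this commit introduces: instead
# of each module computing `importlib.util.find_spec(...) is not None` at
# import time, callers use the cached predicates added to vllm.utils.
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx

if has_deep_gemm():  # result is cached after the first probe
    import deep_gemm  # noqa: F401 -- optional kernel package

print("deep_ep:", has_deep_ep(), "pplx_kernels:", has_pplx())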
@@ -6,7 +6,6 @@ fp8 block-quantized case.
 """
 
 import dataclasses
-import importlib
 from typing import Optional
 
 import pytest
@@ -21,18 +20,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
+from vllm.utils import has_deep_ep, has_deep_gemm
 
 from .utils import ProcessGroupInfo, parallel_launch
 
-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-
-try:
-    import deep_gemm
-    has_deep_gemm = True
-except ImportError:
-    has_deep_gemm = False
-
-if has_deep_ep:
+if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
@@ -40,19 +32,21 @@ if has_deep_ep:
 
     from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
 
-if has_deep_gemm:
+if has_deep_gemm():
+    import deep_gemm
+
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
         BatchedDeepGemmExperts)
     from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
         DeepGemmExperts)
 
 requires_deep_ep = pytest.mark.skipif(
-    not has_deep_ep,
+    not has_deep_ep(),
     reason="Requires deep_ep kernels",
 )
 
 requires_deep_gemm = pytest.mark.skipif(
-    not has_deep_gemm,
+    not has_deep_gemm(),
     reason="Requires deep_gemm kernels",
 )
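For context, a minimal sketch of how a module-level skipif marker like the `requires_deep_gemm` defined above is applied to a test; the test function here is hypothetical and not part of this diff:

import pytest

from vllm.utils import has_deep_gemm

requires_deep_gemm = pytest.mark.skipif(
    not has_deep_gemm(),
    reason="Requires deep_gemm kernels",
)


@requires_deep_gemm
def test_deep_gemm_available():
    # Hypothetical test: runs only when the optional deep_gemm package can be
    # imported; otherwise pytest reports it as skipped.
    import deep_gemm  # noqa: F401
    assert has_deep_gemm()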
@@ -4,7 +4,6 @@ Test deepep dispatch-combine logic
 """
 
 import dataclasses
-import importlib
 from typing import Optional, Union
 
 import pytest
@@ -22,12 +21,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
+from vllm.utils import has_deep_ep
 
 from .utils import ProcessGroupInfo, parallel_launch
 
-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-
-if has_deep_ep:
+if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
@@ -36,7 +34,7 @@ if has_deep_ep:
     from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
 
 requires_deep_ep = pytest.mark.skipif(
-    not has_deep_ep,
+    not has_deep_ep(),
     reason="Requires deep_ep kernels",
 )
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import importlib.util
 from typing import TYPE_CHECKING, Any
 
 import torch
@@ -8,6 +7,7 @@ import torch.distributed as dist
 
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
+from vllm.utils import has_deep_ep, has_pplx
 
 from .base_device_communicator import All2AllManagerBase, Cache
 
@@ -80,8 +80,8 @@ class PPLXAll2AllManager(All2AllManagerBase):
     """
 
     def __init__(self, cpu_group):
-        has_pplx = importlib.util.find_spec("pplx_kernels") is not None
-        assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
+        assert has_pplx(
+        ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
         super().__init__(cpu_group)
 
         if self.internode:
@@ -133,8 +133,8 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     """
 
     def __init__(self, cpu_group):
-        has_deepep = importlib.util.find_spec("deep_ep") is not None
-        assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
+        assert has_deep_ep(
+        ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
         super().__init__(cpu_group)
         self.handle_cache = Cache()
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import importlib.util
 from typing import Optional
 
 import torch
@@ -11,8 +10,6 @@ from vllm.triton_utils import tl, triton
 
 logger = init_logger(__name__)
 
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-
 
 @triton.jit
 def _silu_mul_fp8_quant_deep_gemm(
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
-import importlib.util
 from typing import Optional
 
 import torch
@@ -12,14 +11,13 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
     _moe_permute)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
-from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, per_token_group_quant_fp8)
-from vllm.utils import round_up
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
+from vllm.utils import has_deep_gemm, round_up
 
 logger = init_logger(__name__)
 
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-
 
 @functools.cache
 def deep_gemm_block_shape() -> list[int]:
@@ -41,7 +39,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
     gemm kernel. All of M, N, K and the quantization block_shape must be
     aligned by `dg.get_m_alignment_for_contiguous_layout()`.
     """
-    if not has_deep_gemm:
+    if not has_deep_gemm():
         logger.debug("DeepGemm disabled: deep_gemm not available.")
         return False
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import importlib
 from abc import abstractmethod
 from collections.abc import Iterable
 from dataclasses import dataclass
@@ -32,10 +31,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import direct_register_custom_op
-
-has_pplx = importlib.util.find_spec("pplx_kernels") is not None
-has_deepep = importlib.util.find_spec("deep_ep") is not None
+from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
 
 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
@@ -43,9 +39,9 @@ if current_platform.is_cuda_alike():
     from .modular_kernel import (FusedMoEModularKernel,
                                  FusedMoEPermuteExpertsUnpermute,
                                  FusedMoEPrepareAndFinalize)
-    if has_pplx:
+    if has_pplx():
         from .pplx_prepare_finalize import PplxPrepareAndFinalize
-    if has_deepep:
+    if has_deep_ep():
         from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
         from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
                                                  DeepEPLLPrepareAndFinalize)
@@ -104,4 +104,4 @@ def find_free_port():
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
         s.bind(('', 0))
         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        return s.getsockname()[1]
+        return s.getsockname()[1]
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import enum
-import importlib
 from enum import Enum
 from typing import Callable, Optional
 
@@ -29,13 +28,12 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-
-has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+from vllm.utils import has_pplx
 
 if current_platform.is_cuda_alike():
     from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
         BatchedPrepareAndFinalize)
-    if has_pplx:
+    if has_pplx():
         from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
             PplxPrepareAndFinalize)
@@ -577,7 +575,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             use_batched_format=True,
         )
 
-        if has_pplx and isinstance(
+        if has_pplx() and isinstance(
                 prepare_finalize,
                 (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
             # no expert_map support in this case
@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
-import importlib.util
 import logging
 
 import torch
 
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, has_deep_gemm
 
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-if has_deep_gemm:
+if has_deep_gemm():
     import deep_gemm
 
 logger = logging.getLogger(__name__)
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import functools
-import importlib.util
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -38,13 +37,12 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils import has_deep_gemm
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
 
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-
 
 def _is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3
@@ -451,7 +449,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         # Check for DeepGemm support.
         self.allow_deep_gemm = False
         if envs.VLLM_USE_DEEP_GEMM:
-            if not has_deep_gemm:
+            if not has_deep_gemm():
                 logger.warning_once("Failed to import DeepGemm kernels.")
             elif not self.block_quant:
                 logger.warning_once("Model is not block quantized. Not using "
@@ -3,7 +3,6 @@
 
 # Adapted from https://github.com/sgl-project/sglang/pull/2575
 import functools
-import importlib.util
 import json
 import os
 from typing import Any, Callable, Optional, Union
@@ -19,10 +18,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
 
 logger = init_logger(__name__)
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
 
 
 def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
@@ -109,7 +107,7 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
     """
 
     return (current_platform.is_cuda()
-            and current_platform.is_device_capability(90) and has_deep_gemm
+            and current_platform.is_device_capability(90) and has_deep_gemm()
            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
@@ -2929,3 +2929,31 @@ def is_torch_equal_or_newer(target: str) -> bool:
 def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
     torch_version = version.parse(torch_version)
     return torch_version >= version.parse(target)
+
+
+@cache
+def _has_module(module_name: str) -> bool:
+    """Return True if *module_name* can be found in the current environment.
+
+    The result is cached so that subsequent queries for the same module incur
+    no additional overhead.
+    """
+    return importlib.util.find_spec(module_name) is not None
+
+
+def has_pplx() -> bool:
+    """Whether the optional `pplx_kernels` package is available."""
+    return _has_module("pplx_kernels")
+
+
+def has_deep_ep() -> bool:
+    """Whether the optional `deep_ep` package is available."""
+    return _has_module("deep_ep")
+
+
+def has_deep_gemm() -> bool:
+    """Whether the optional `deep_gemm` package is available."""
+    return _has_module("deep_gemm")
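A small standalone demonstration of the caching behaviour these helpers rely on, using only the standard library; the `@cache` decorator in the hunk above is presumably `functools.cache` brought in through the module's existing imports, and the probe counter below is purely illustrative:

import importlib.util
from functools import cache

probe_count = 0


@cache
def _has_module(module_name: str) -> bool:
    # Mirrors the new vllm.utils helper: find_spec runs once per module name
    # per process; later calls return the memoized result.
    global probe_count
    probe_count += 1
    return importlib.util.find_spec(module_name) is not None


def has_deep_gemm() -> bool:
    return _has_module("deep_gemm")


if __name__ == "__main__":
    has_deep_gemm()
    has_deep_gemm()
    # Only the first call actually probed the environment, so a package
    # installed after that first probe is not seen until the process restarts.
    assert probe_count == 1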