mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 12:16:13 +08:00)
[Refactor] Create a function util and cache the results for has_deepgemm, has_deepep, has_pplx (#20187)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
daec9dea6e
commit
4d36693687
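
In short: each call site used to run its own availability probe at import time, either via `importlib.util.find_spec` or a try/except import; this commit centralizes those probes behind cached helper functions in `vllm.utils`. A condensed before/after sketch of the pattern, taken from the hunks below:

# Before: every module repeated the probe itself.
import importlib.util
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
if has_deep_gemm:
    import deep_gemm

# After: one shared, cached helper (defined in vllm/utils in the final hunk).
from vllm.utils import has_deep_gemm
if has_deep_gemm():
    import deep_gemm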
@@ -6,7 +6,6 @@ fp8 block-quantized case.
 """
 
 import dataclasses
-import importlib
 from typing import Optional
 
 import pytest
@@ -21,18 +20,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
+from vllm.utils import has_deep_ep, has_deep_gemm
 
 from .utils import ProcessGroupInfo, parallel_launch
 
-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-
-try:
-    import deep_gemm
-    has_deep_gemm = True
-except ImportError:
-    has_deep_gemm = False
-
-if has_deep_ep:
+if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
@@ -40,19 +32,21 @@ if has_deep_ep:
 
 from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
 
-if has_deep_gemm:
+if has_deep_gemm():
+    import deep_gemm
+
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
         BatchedDeepGemmExperts)
     from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
         DeepGemmExperts)
 
 requires_deep_ep = pytest.mark.skipif(
-    not has_deep_ep,
+    not has_deep_ep(),
     reason="Requires deep_ep kernels",
 )
 
 requires_deep_gemm = pytest.mark.skipif(
-    not has_deep_gemm,
+    not has_deep_gemm(),
     reason="Requires deep_gemm kernels",
 )
 
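Because `has_deep_ep` and `has_deep_gemm` are now plain callables backed by a cached probe, the skipif markers above evaluate availability once at import time. A usage sketch (the test name is hypothetical, for illustration only):

import pytest
from vllm.utils import has_deep_gemm

requires_deep_gemm = pytest.mark.skipif(
    not has_deep_gemm(),
    reason="Requires deep_gemm kernels",
)

@requires_deep_gemm
def test_uses_deep_gemm():  # hypothetical test name
    import deep_gemm  # importable whenever the marker lets the test run
    assert deep_gemm is not None
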
@@ -4,7 +4,6 @@ Test deepep dispatch-combine logic
 """
 
 import dataclasses
-import importlib
 from typing import Optional, Union
 
 import pytest

@@ -22,12 +21,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
+from vllm.utils import has_deep_ep
 
 from .utils import ProcessGroupInfo, parallel_launch
 
-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-
-if has_deep_ep:
+if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501

@@ -36,7 +34,7 @@ if has_deep_ep:
 from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
 
 requires_deep_ep = pytest.mark.skipif(
-    not has_deep_ep,
+    not has_deep_ep(),
     reason="Requires deep_ep kernels",
 )
 
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import importlib.util
 from typing import TYPE_CHECKING, Any
 
 import torch

@@ -8,6 +7,7 @@ import torch.distributed as dist
 
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
+from vllm.utils import has_deep_ep, has_pplx
 
 from .base_device_communicator import All2AllManagerBase, Cache
 

@@ -80,8 +80,8 @@ class PPLXAll2AllManager(All2AllManagerBase):
     """
 
     def __init__(self, cpu_group):
-        has_pplx = importlib.util.find_spec("pplx_kernels") is not None
-        assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
+        assert has_pplx(
+        ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
         super().__init__(cpu_group)
 
         if self.internode:

@@ -133,8 +133,8 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     """
 
     def __init__(self, cpu_group):
-        has_deepep = importlib.util.find_spec("deep_ep") is not None
-        assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
+        assert has_deep_ep(
+        ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
         super().__init__(cpu_group)
         self.handle_cache = Cache()
 
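Both managers above share the same fail-fast pattern: probe once through the cached helper and raise immediately with an install pointer when the optional kernels are missing. A minimal standalone sketch of that pattern (the FeatureManager class name is hypothetical):

from vllm.utils import has_deep_ep

class FeatureManager:  # hypothetical name, for illustration
    def __init__(self, cpu_group):
        # Cached probe: repeated constructions never re-scan the import path.
        assert has_deep_ep(), (
            "DeepEP kernels not found. Please follow "
            "https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
            " to install DeepEP kernels.")
        self.cpu_group = cpu_group
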
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import importlib.util
 from typing import Optional
 
 import torch

@@ -11,8 +10,6 @@ from vllm.triton_utils import tl, triton
 
 logger = init_logger(__name__)
 
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-
 
 @triton.jit
 def _silu_mul_fp8_quant_deep_gemm(
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
-import importlib.util
 from typing import Optional
 
 import torch

@@ -12,14 +11,13 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
     _moe_permute)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
-from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, per_token_group_quant_fp8)
-from vllm.utils import round_up
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
+from vllm.utils import has_deep_gemm, round_up
 
 logger = init_logger(__name__)
 
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-
 
 @functools.cache
 def deep_gemm_block_shape() -> list[int]:

@@ -41,7 +39,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
     gemm kernel. All of M, N, K and the quantization block_shape must be
     aligned by `dg.get_m_alignment_for_contiguous_layout()`.
     """
-    if not has_deep_gemm:
+    if not has_deep_gemm():
         logger.debug("DeepGemm disabled: deep_gemm not available.")
         return False
 
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import importlib
 from abc import abstractmethod
 from collections.abc import Iterable
 from dataclasses import dataclass

@@ -32,10 +31,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import direct_register_custom_op
-
-has_pplx = importlib.util.find_spec("pplx_kernels") is not None
-has_deepep = importlib.util.find_spec("deep_ep") is not None
+from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
 
 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts

@@ -43,9 +39,9 @@ if current_platform.is_cuda_alike():
     from .modular_kernel import (FusedMoEModularKernel,
                                  FusedMoEPermuteExpertsUnpermute,
                                  FusedMoEPrepareAndFinalize)
-    if has_pplx:
+    if has_pplx():
         from .pplx_prepare_finalize import PplxPrepareAndFinalize
-    if has_deepep:
+    if has_deep_ep():
         from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
         from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
                                                  DeepEPLLPrepareAndFinalize)
@@ -104,4 +104,4 @@ def find_free_port():
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
         s.bind(('', 0))
         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         return s.getsockname()[1]
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import enum
-import importlib
 from enum import Enum
 from typing import Callable, Optional
 

@@ -29,13 +28,12 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-
-has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+from vllm.utils import has_pplx
 
 if current_platform.is_cuda_alike():
     from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
         BatchedPrepareAndFinalize)
-    if has_pplx:
+    if has_pplx():
         from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
             PplxPrepareAndFinalize)
 

@@ -577,7 +575,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             use_batched_format=True,
         )
 
-        if has_pplx and isinstance(
+        if has_pplx() and isinstance(
                 prepare_finalize,
                 (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
             # no expert_map support in this case
@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
-import importlib.util
 import logging
 
 import torch
 
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, has_deep_gemm
 
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-if has_deep_gemm:
+if has_deep_gemm():
     import deep_gemm
 
 logger = logging.getLogger(__name__)
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import functools
-import importlib.util
 from typing import Any, Callable, Optional, Union
 
 import torch

@@ -38,13 +37,12 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils import has_deep_gemm
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
 
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-
 
 def _is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3

@@ -451,7 +449,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         # Check for DeepGemm support.
         self.allow_deep_gemm = False
         if envs.VLLM_USE_DEEP_GEMM:
-            if not has_deep_gemm:
+            if not has_deep_gemm():
                 logger.warning_once("Failed to import DeepGemm kernels.")
             elif not self.block_quant:
                 logger.warning_once("Model is not block quantized. Not using "
@@ -3,7 +3,6 @@
 
 # Adapted from https://github.com/sgl-project/sglang/pull/2575
 import functools
-import importlib.util
 import json
 import os
 from typing import Any, Callable, Optional, Union

@@ -19,10 +18,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
 
 logger = init_logger(__name__)
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
 
 
 def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:

@@ -109,7 +107,7 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
     """
 
     return (current_platform.is_cuda()
-            and current_platform.is_device_capability(90) and has_deep_gemm
+            and current_platform.is_device_capability(90) and has_deep_gemm()
             and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
             and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
 
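Every clause of `should_use_deepgemm` must hold at once. A hedged usage sketch (the weight tensor is hypothetical, shaped so both dimensions are multiples of 128; judging by the imports elsewhere in this diff, the function likely lives in vllm.model_executor.layers.quantization.utils.fp8_utils, which is an assumption):

import torch
# assumed location, not shown in this diff:
# from vllm.model_executor.layers.quantization.utils.fp8_utils import should_use_deepgemm

# 256 % 128 == 0 and 512 % 128 == 0, so the shape checks pass; the remaining
# clauses require a CUDA SM90 device, VLLM_USE_DEEP_GEMM, deep_gemm being
# installed, and a bfloat16 output dtype.
weight = torch.empty(256, 512)
use_dg = should_use_deepgemm(torch.bfloat16, weight)
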
@@ -2929,3 +2929,31 @@ def is_torch_equal_or_newer(target: str) -> bool:
 def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
     torch_version = version.parse(torch_version)
     return torch_version >= version.parse(target)
+
+
+@cache
+def _has_module(module_name: str) -> bool:
+    """Return True if *module_name* can be found in the current environment.
+
+    The result is cached so that subsequent queries for the same module incur
+    no additional overhead.
+    """
+    return importlib.util.find_spec(module_name) is not None
+
+
+def has_pplx() -> bool:
+    """Whether the optional `pplx_kernels` package is available."""
+
+    return _has_module("pplx_kernels")
+
+
+def has_deep_ep() -> bool:
+    """Whether the optional `deep_ep` package is available."""
+
+    return _has_module("deep_ep")
+
+
+def has_deep_gemm() -> bool:
+    """Whether the optional `deep_gemm` package is available."""
+
+    return _has_module("deep_gemm")
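
Because `_has_module` is wrapped in `functools.cache` (an unbounded `lru_cache`), the first call per module name pays for the `importlib.util.find_spec` lookup and every later call returns the memoized result. A self-contained sketch of the same caching behavior, probing a standard-library module:

import importlib.util
from functools import cache

@cache
def _has_module(module_name: str) -> bool:
    # find_spec walks the import machinery; caching avoids repeating that work.
    return importlib.util.find_spec(module_name) is not None

print(_has_module("json"))        # True: miss, lookup performed and stored
print(_has_module("json"))        # True: hit, served from the cache
print(_has_module.cache_info())   # CacheInfo(hits=1, misses=1, ...)

One consequence worth noting: the probe result is fixed for the lifetime of the process, so a package installed after the first call will not be seen until restart.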