[Feature] Add Hopper DeepGEMM E8M0 for DeepSeekV3.1 scale_fmt (#23666)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>

parent 513c1fe255
commit 3af47c3cc6
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, modular_triton_fused_moe)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

 dg_available = has_deep_gemm()

@@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
-@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
-                    reason="Not E8M0 scale MOE")
+@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                             monkeypatch):

@@ -20,8 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
-                                  is_deep_gemm_supported)
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported

 from ...utils import multi_gpu_test
 from .parallel_utils import ProcessGroupInfo, parallel_launch

@@ -374,7 +373,7 @@ NUM_EXPERTS = [32]
 @multi_gpu_test(num_gpus=2)
 @requires_deep_ep
 @requires_deep_gemm
-@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
+@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
                     reason="Skipping test for Blackwell DeepGEMM")
 def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
                                 topk: int, world_dp_size: tuple[int, int]):

@@ -432,7 +431,7 @@ USE_FP8_DISPATCH = [False]
 @multi_gpu_test(num_gpus=2)
 @requires_deep_ep
 @requires_deep_gemm
-@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
+@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
                     reason="Skipping test for Blackwell DeepGEMM")
 def test_ll_deepep_deepgemm_moe(
     mnk: tuple[int, int, int],

@@ -131,6 +131,7 @@ if TYPE_CHECKING:
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_USE_DEEP_GEMM_E8M0: bool = True
+    VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
     VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False

@@ -954,9 +955,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),

     # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
-    # E8M0 is faster on B200 but may reduce accuracy.
     "VLLM_USE_DEEP_GEMM_E8M0":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
+    # TODO(wentao): unify the two E8M0 flags after verifying the correctness.
+    # Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs.
+    "VLLM_USE_DEEP_GEMM_E8M0_HOPPER":
+    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))),
     # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
     # JIT all the required kernels before model execution so there is no
     # JIT'ing in the hot-path. However, this warmup increases the engine

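Taken together, the two flags give Blackwell an opt-out and Hopper an opt-in. A minimal sketch of the combined gating, for orientation only (`capability` stands in for the SM version reported by `current_platform`; the authoritative check is `is_deep_gemm_e8m0_used` further down):

    import os

    def e8m0_enabled(capability: int) -> bool:
        # Blackwell (SM100): on by default; opt out via VLLM_USE_DEEP_GEMM_E8M0=0.
        if capability == 100:
            return bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1")))
        # Hopper (SM90): off by default; opt in via VLLM_USE_DEEP_GEMM_E8M0_HOPPER=1.
        if capability == 90:
            return bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0")))
        return False
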
@@ -1244,6 +1248,8 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_SAMPLER",
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
+        "VLLM_USE_DEEP_GEMM_E8M0",
+        "VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
         "VLLM_USE_TRTLLM_FP4_GEMM",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
         "VLLM_USE_FLASHINFER_MOE_FP8",

@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
-                                  is_blackwell_deep_gemm_e8m0_used)
+                                  is_deep_gemm_e8m0_used)

 logger = init_logger(__name__)

@@ -174,7 +174,7 @@ def silu_mul_fp8_quant_deep_gemm(
         eps,
         fp8_min,
         fp8_max,
-        is_blackwell_deep_gemm_e8m0_used(),
+        is_deep_gemm_e8m0_used(),
         BLOCK=group_size,
         NUM_STAGES=4,
         num_warps=1,

@@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

@@ -1431,9 +1431,8 @@ def fused_experts(hidden_states: torch.Tensor,
     # E8M0 scale, which means we requantize the weight and input to the specific
     # scale. Fallen back to cutlass or triton for some cases would cause
     # accuracy issue.
-    if (allow_deep_gemm and use_fp8_w8a8
-            and (is_blackwell_deep_gemm_e8m0_used()
-                 or _valid_deep_gemm(hidden_states, w1, w2))):
+    if (allow_deep_gemm and use_fp8_w8a8 and
+            (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))):
         assert apply_router_weight_on_input is False
         assert is_act_and_mul, (
             "DeepGemm only supports is_act_and_mul=True for now.")

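The rewritten condition is deliberately asymmetric: once E8M0 scales are in play, the weights have already been requantized for DeepGEMM, so the usual shape-based fallback to the Cutlass or Triton kernels must not fire. A condensed sketch of the dispatch rule (parameter names are shorthand; `_valid_deep_gemm` is the real shape/alignment check):

    def use_deep_gemm_path(allow_deep_gemm: bool, use_fp8_w8a8: bool,
                           e8m0_used: bool, shapes_ok: bool) -> bool:
        # With E8M0-requantized weights, a Triton/Cutlass fallback would
        # compute with mismatched scales, so e8m0_used alone forces this path.
        return allow_deep_gemm and use_fp8_w8a8 and (e8m0_used or shapes_ok)
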
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
     deep_gemm_block_shape)
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used


 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

@@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
-        if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used()
+        if self.allow_deep_gemm and (is_deep_gemm_e8m0_used()
                                      or _valid_deep_gemm_shape(M, N, K)):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(

@@ -143,7 +143,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     ):
         use_deep_gemm = (self.allow_deep_gemm
                          and (_valid_deep_gemm(hidden_states, w1, w2)
-                              or is_blackwell_deep_gemm_e8m0_used()))
+                              or is_deep_gemm_e8m0_used()))

         experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
         assert experts is not None

@@ -48,8 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
-                                  is_deep_gemm_supported)
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
 from vllm.utils.flashinfer import has_flashinfer_moe

 if TYPE_CHECKING:

@@ -427,7 +426,7 @@ class Fp8LinearMethod(LinearMethodBase):
             # On B200, if E8M0 for DeepGemm is used, we need to
             # requantize the weight and input to the specific scale
             # at the same time.
-            if is_blackwell_deep_gemm_e8m0_used():
+            if is_deep_gemm_e8m0_used():
                 assert layer.weight_block_size is not None
                 block_sz = tuple(layer.weight_block_size)
                 requant_weight_ue8m0_inplace(

@@ -734,7 +733,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):

         # DeepGemm scales need to be transposed and aligned. We try to do
         # it ahead of time for performance reasons.
-        if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
+        if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
             # Lazy import to avoid CUDA initialization problems.
             if _is_col_major(layer.w13_weight_scale_inv):
                 layer.w13_weight_scale_inv = \

@@ -871,7 +870,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         del layer.w13_input_scale
         del layer.w2_input_scale

-        if is_blackwell_deep_gemm_e8m0_used():
+        if is_deep_gemm_e8m0_used():
             assert layer.weight_block_size is not None
             # Re-quantise the expert weights so their scales are UE8M0.
             block_sz = tuple(layer.weight_block_size)

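Both Fp8 hunks above rely on `requant_weight_ue8m0_inplace` to make the block scales of already-quantized FP8 weights representable in UE8M0. A hypothetical single-tensor illustration of the invariant such a requantization must preserve (not vLLM's implementation, which operates in place over blocked weight layouts):

    import torch

    def requant_to_ue8m0(w_fp8: torch.Tensor,
                         scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Round the scale up to the next power of two: the only values an
        # exponent-only UE8M0 scale can encode. Then re-quantize so the
        # dequantized product stays (approximately) unchanged.
        new_scale = torch.exp2(torch.ceil(torch.log2(scale)))
        w = w_fp8.to(torch.float32) * scale                  # dequantize
        w_requant = (w / new_scale).to(torch.float8_e4m3fn)  # requantize
        return w_requant, new_scale
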
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, direct_register_custom_op
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
+from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
                                   should_use_deepgemm_for_fp8_linear)

 logger = init_logger(__name__)

@@ -385,7 +385,7 @@ def per_token_group_quant_fp8(
         scaling factor.
     """
     if use_ue8m0 is None:
-        use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
+        use_ue8m0 = is_deep_gemm_e8m0_used()
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "

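`per_token_group_quant_fp8` now derives its default scale format from the unified check. For intuition, a hypothetical eager-mode reference of the grouped quantization the fused Triton kernel performs (names and the eps guard are illustrative):

    import torch

    def per_group_quant(x: torch.Tensor, group_size: int = 128,
                        use_ue8m0: bool = False):
        # One scale per contiguous group of `group_size` elements along the
        # last dim, with an eps guard against all-zero groups.
        xg = x.view(*x.shape[:-1], -1, group_size)
        amax = xg.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-10)
        scale = amax / torch.finfo(torch.float8_e4m3fn).max
        if use_ue8m0:
            # Same power-of-two rounding as in the requantization sketch above.
            scale = torch.exp2(torch.ceil(torch.log2(scale)))
        q = (xg / scale).to(torch.float8_e4m3fn)
        return q.view(x.shape), scale.squeeze(-1)
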
@@ -501,6 +501,24 @@ def get_config(

     if quantization_config is not None:
         config.quantization_config = quantization_config
+        # auto-enable DeepGEMM UE8M0 on Hopper if model config requests it
+        scale_fmt = quantization_config.get("scale_fmt", None)
+        if scale_fmt in ("ue8m0", ):
+            if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0_HOPPER"):
+                os.environ["VLLM_USE_DEEP_GEMM_E8M0_HOPPER"] = "1"
+                logger.info_once(
+                    ("Detected quantization_config.scale_fmt=%s; "
+                     "enabling Hopper UE8M0."),
+                    scale_fmt,
+                )
+            elif not envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER:
+                logger.warning_once(
+                    ("Model config requests UE8M0 "
+                     "(quantization_config.scale_fmt=%s), but "
+                     "VLLM_USE_DEEP_GEMM_E8M0_HOPPER=0 is set; "
+                     "Hopper UE8M0 disabled."),
+                    scale_fmt,
+                )

     if hf_overrides_kw:
         logger.debug("Overriding HF config with %s", hf_overrides_kw)

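For reference, DeepSeekV3.1-style checkpoints declare the scale format inside the HF `config.json`. A `quantization_config` shaped like the following would take the auto-enable branch above (every field except `scale_fmt` is illustrative, not quoted from a real checkpoint):

    quantization_config = {
        "quant_method": "fp8",
        "weight_block_size": [128, 128],
        "scale_fmt": "ue8m0",
    }
    # get_config() then exports VLLM_USE_DEEP_GEMM_E8M0_HOPPER=1, unless the
    # user already set that variable, in which case their choice is respected
    # and a warning is logged if it disables UE8M0.
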
@@ -31,34 +31,33 @@ def is_deep_gemm_supported() -> bool:


 @functools.cache
-def is_blackwell_deep_gemm_e8m0_used() -> bool:
+def is_deep_gemm_e8m0_used() -> bool:
     """Return ``True`` if vLLM is configured to use DeepGEMM "
-    "E8M0 scale on a Blackwell-class GPU.
+    "E8M0 scale on a Hopper or Blackwell-class GPU.
     """
     if not is_deep_gemm_supported():
-        logger.debug_once(
+        logger.info_once(
             "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.")
         return False

-    if not envs.VLLM_USE_DEEP_GEMM_E8M0:
-        logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.")
-        return False
-
     _lazy_init()

     if _fp8_gemm_nt_impl is None:
-        logger.debug_once(
-            "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
+        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
         return False

-    enabled = (current_platform.is_cuda()
-               and current_platform.has_device_capability(100))
-    if enabled:
-        logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
-    else:
-        logger.debug_once(
-            "DeepGEMM E8M0 disabled: not running on Blackwell GPU.")
-    return enabled
+    if current_platform.is_device_capability(100) and \
+            envs.VLLM_USE_DEEP_GEMM_E8M0:
+        logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
+        return True
+
+    if current_platform.is_device_capability(90) and \
+            envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER:
+        logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.")
+        return True
+
+    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
+    return False


 def _missing(*_: Any, **__: Any) -> NoReturn:

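One operational note on the rewritten check: `@functools.cache` pins the result after the first call, so environment changes made later in the same process are invisible to it. Code that toggles these variables afterwards, such as the updated MoE tests that use `monkeypatch`, should reset the cache first:

    from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

    # functools.cache exposes cache_clear(); call it after monkeypatching
    # VLLM_USE_DEEP_GEMM_E8M0 / VLLM_USE_DEEP_GEMM_E8M0_HOPPER so the new
    # values are actually re-read.
    is_deep_gemm_e8m0_used.cache_clear()
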
@@ -124,20 +123,18 @@ def fp8_gemm_nt(*args, **kwargs):
     _lazy_init()
     if _fp8_gemm_nt_impl is None:
         return _missing(*args, **kwargs)
-    return _fp8_gemm_nt_impl(
-        *args,
-        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
-        **kwargs)
+    return _fp8_gemm_nt_impl(*args,
+                             disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
+                             **kwargs)


 def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
     _lazy_init()
     if _grouped_impl is None:
         return _missing(*args, **kwargs)
-    return _grouped_impl(
-        *args,
-        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
-        **kwargs)
+    return _grouped_impl(*args,
+                         disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
+                         **kwargs)


 def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):

@@ -145,9 +142,7 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     if _grouped_masked_impl is None:
         return _missing(*args, **kwargs)
     return _grouped_masked_impl(
-        *args,
-        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
-        **kwargs)
+        *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)


 def _ceil_to_ue8m0(x: torch.Tensor):

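All three wrappers in the two hunks above follow the same optional-dependency pattern: resolve the DeepGEMM symbol lazily, fail through `_missing` when it is absent, and thread the E8M0 policy into every call via `disable_ue8m0_cast`. A schematic sketch of that pattern (simplified; the module's real `_lazy_init` resolves several symbols at once, and `fp8_gemm_nt` as a deep_gemm attribute name is assumed here):

    from typing import Any, Callable, NoReturn, Optional

    from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

    _impl: Optional[Callable[..., Any]] = None

    def _lazy_init() -> None:
        global _impl
        if _impl is not None:
            return
        try:
            import deep_gemm  # optional dependency
            _impl = getattr(deep_gemm, "fp8_gemm_nt", None)
        except ImportError:
            _impl = None

    def _missing(*_: Any, **__: Any) -> NoReturn:
        raise RuntimeError("DeepGEMM backend is not available.")

    def fp8_gemm_nt(*args, **kwargs):
        _lazy_init()
        if _impl is None:
            return _missing(*args, **kwargs)
        # The E8M0 decision is made once, in is_deep_gemm_e8m0_used, and
        # threaded through every wrapped call.
        return _impl(*args,
                     disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
                     **kwargs)
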
@@ -211,7 +206,7 @@ __all__ = [
     "m_grouped_fp8_gemm_nt_contiguous",
     "fp8_m_grouped_gemm_nt_masked",
     "per_block_cast_to_fp8",
-    "is_blackwell_deep_gemm_e8m0_used",
+    "is_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
     "should_use_deepgemm_for_fp8_linear",
 ]