mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:55:51 +08:00
[Feature] Add VLLM_USE_DEEP_GEMM_E8M0 Env to Control E8M0 Scale (#21968)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
8e13d9fe6d
commit
f7dcce7a4a
@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
|
||||
|
||||
dg_available = has_deep_gemm()
|
||||
|
||||
@ -224,7 +224,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||
@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
|
||||
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
|
||||
reason="Not E8M0 scale MOE")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
monkeypatch):
|
||||
|
||||
@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
|
||||
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported)
|
||||
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
@ -370,7 +370,7 @@ NUM_EXPERTS = [32]
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
|
||||
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
|
||||
reason="Skipping test for Blackwell DeepGEMM")
|
||||
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
||||
topk: int, world_dp_size: tuple[int, int]):
|
||||
@ -427,7 +427,7 @@ USE_FP8_DISPATCH = [False]
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
|
||||
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
|
||||
reason="Skipping test for Blackwell DeepGEMM")
|
||||
def test_ll_deepep_deepgemm_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
|
||||
@ -127,6 +127,7 @@ if TYPE_CHECKING:
|
||||
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
|
||||
VLLM_TPU_USING_PATHWAYS: bool = False
|
||||
VLLM_USE_DEEP_GEMM: bool = False
|
||||
VLLM_USE_DEEP_GEMM_E8M0: bool = True
|
||||
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
|
||||
@ -925,6 +926,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_DEEP_GEMM":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
|
||||
|
||||
# Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
|
||||
# E8M0 is faster on B200 but may reduce accuracy.
|
||||
"VLLM_USE_DEEP_GEMM_E8M0":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
|
||||
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
|
||||
# JIT all the required kernels before model execution so there is no
|
||||
# JIT'ing in the hot-path. However, this warmup increases the engine
|
||||
|
||||
@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
|
||||
is_blackwell_deep_gemm_used)
|
||||
is_blackwell_deep_gemm_e8m0_used)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -176,7 +176,7 @@ def silu_mul_fp8_quant_deep_gemm(
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
is_blackwell_deep_gemm_used(),
|
||||
is_blackwell_deep_gemm_e8m0_used(),
|
||||
BLOCK=group_size,
|
||||
NUM_STAGES=8,
|
||||
num_warps=1,
|
||||
|
||||
@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
|
||||
|
||||
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
||||
|
||||
@ -1387,8 +1387,8 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
# E8M0 scale, which means we requantize the weight and input to the specific
|
||||
# scale. Fallen back to cutlass or triton for some cases would cause
|
||||
# accuracy issue.
|
||||
should_use_deep_gemm = is_blackwell_deep_gemm_used() or _valid_deep_gemm(
|
||||
hidden_states, w1, w2)
|
||||
should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used(
|
||||
) or _valid_deep_gemm(hidden_states, w1, w2)
|
||||
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
|
||||
assert apply_router_weight_on_input is False
|
||||
assert is_act_and_mul, (
|
||||
|
||||
@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
|
||||
deep_gemm_block_shape)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
|
||||
|
||||
|
||||
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||
if self.allow_deep_gemm and (is_blackwell_deep_gemm_used()
|
||||
if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used()
|
||||
or _valid_deep_gemm_shape(M, N, K)):
|
||||
assert self.deep_gemm_expert is not None
|
||||
return self.deep_gemm_expert.workspace_shapes(
|
||||
@ -133,7 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
extra_expert_args: Optional[dict[str, Any]]):
|
||||
use_deep_gemm = (self.allow_deep_gemm
|
||||
and (_valid_deep_gemm(hidden_states, w1, w2)
|
||||
or is_blackwell_deep_gemm_used()))
|
||||
or is_blackwell_deep_gemm_e8m0_used()))
|
||||
|
||||
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
|
||||
assert experts is not None
|
||||
|
||||
@ -45,7 +45,8 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported)
|
||||
from vllm.utils.flashinfer import has_flashinfer_moe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -415,10 +416,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale
|
||||
|
||||
# On B200, DeepGemm only support E8M0 scale, which means we need to
|
||||
# On B200, if E8M0 for DeepGemm is used, we need to
|
||||
# requantize the weight and input to the specific scale
|
||||
# at the same time.
|
||||
if is_blackwell_deep_gemm_used():
|
||||
if is_blackwell_deep_gemm_e8m0_used():
|
||||
assert layer.weight_block_size is not None
|
||||
block_sz = tuple(layer.weight_block_size)
|
||||
requant_weight_ue8m0_inplace(
|
||||
@ -505,15 +506,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
elif not self.block_quant:
|
||||
logger.warning_once("Model is not block quantized. Not using "
|
||||
"DeepGemm kernels")
|
||||
elif (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)):
|
||||
elif (is_deep_gemm_supported()):
|
||||
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
|
||||
self.allow_deep_gemm = True
|
||||
elif (current_platform.is_cuda()
|
||||
and is_blackwell_deep_gemm_used()):
|
||||
logger.info_once("Using DeepGemm SM100 kernels for "
|
||||
"Fp8MoEMethod.")
|
||||
self.allow_deep_gemm = True
|
||||
else:
|
||||
logger.warning_once(
|
||||
"DeepGemm not supported on the current platform.")
|
||||
@ -725,7 +720,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# DeepGemm scales need to be transposed and aligned. We try to do
|
||||
# it ahead of time for performance reasons.
|
||||
if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
|
||||
if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
if _is_col_major(layer.w13_weight_scale_inv):
|
||||
layer.w13_weight_scale_inv = \
|
||||
@ -851,7 +846,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
|
||||
if is_blackwell_deep_gemm_used():
|
||||
if is_blackwell_deep_gemm_e8m0_used():
|
||||
assert layer.weight_block_size is not None
|
||||
# Re-quantise the expert weights so their scales are UE8M0.
|
||||
block_sz = tuple(layer.weight_block_size)
|
||||
|
||||
@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -394,10 +394,8 @@ def per_token_group_quant_fp8(
|
||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor.
|
||||
"""
|
||||
# TODO(wentao): refactor this
|
||||
# use_ue8m0 should be a global flag that could be set by user
|
||||
if use_ue8m0 is None:
|
||||
use_ue8m0 = is_blackwell_deep_gemm_used()
|
||||
use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
assert (x.shape[-1] % group_size == 0), (
|
||||
f"the last dimension of `x` {x.shape[-1]} must be divisible "
|
||||
|
||||
@ -31,19 +31,37 @@ def is_deep_gemm_supported() -> bool:
|
||||
|
||||
|
||||
@functools.cache
|
||||
def is_blackwell_deep_gemm_used() -> bool:
|
||||
"""Return ``True`` if vLLM is configured to use DeepGEMM on a
|
||||
Blackwell-class GPU.
|
||||
def is_blackwell_deep_gemm_e8m0_used() -> bool:
|
||||
"""Return ``True`` if vLLM is configured to use DeepGEMM "
|
||||
"E8M0 scale on a Blackwell-class GPU.
|
||||
"""
|
||||
if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()):
|
||||
if not (envs.VLLM_USE_DEEP_GEMM):
|
||||
logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM=0.")
|
||||
return False
|
||||
|
||||
if not has_deep_gemm():
|
||||
logger.debug_once("DeepGEMM E8M0 disabled: DeepGEMM backend missing.")
|
||||
return False
|
||||
|
||||
if not envs.VLLM_USE_DEEP_GEMM_E8M0:
|
||||
logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.")
|
||||
return False
|
||||
|
||||
_lazy_init()
|
||||
|
||||
if _fp8_gemm_nt_impl is None:
|
||||
logger.debug_once(
|
||||
"DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
|
||||
return False
|
||||
|
||||
return (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100))
|
||||
enabled = (current_platform.is_cuda()
|
||||
and current_platform.has_device_capability(100))
|
||||
if enabled:
|
||||
logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
|
||||
else:
|
||||
logger.debug_once(
|
||||
"DeepGEMM E8M0 disabled: not running on Blackwell GPU.")
|
||||
return enabled
|
||||
|
||||
|
||||
def _missing(*_: Any, **__: Any) -> NoReturn:
|
||||
@ -109,21 +127,30 @@ def fp8_gemm_nt(*args, **kwargs):
|
||||
_lazy_init()
|
||||
if _fp8_gemm_nt_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _fp8_gemm_nt_impl(*args, **kwargs)
|
||||
return _fp8_gemm_nt_impl(
|
||||
*args,
|
||||
disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
|
||||
**kwargs)
|
||||
|
||||
|
||||
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
|
||||
_lazy_init()
|
||||
if _grouped_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _grouped_impl(*args, **kwargs)
|
||||
return _grouped_impl(
|
||||
*args,
|
||||
disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
|
||||
**kwargs)
|
||||
|
||||
|
||||
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
|
||||
_lazy_init()
|
||||
if _grouped_masked_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _grouped_masked_impl(*args, **kwargs)
|
||||
return _grouped_masked_impl(
|
||||
*args,
|
||||
disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
|
||||
**kwargs)
|
||||
|
||||
|
||||
def _ceil_to_ue8m0(x: torch.Tensor):
|
||||
@ -181,6 +208,6 @@ __all__ = [
|
||||
"m_grouped_fp8_gemm_nt_contiguous",
|
||||
"fp8_m_grouped_gemm_nt_masked",
|
||||
"per_block_cast_to_fp8",
|
||||
"is_blackwell_deep_gemm_used",
|
||||
"is_blackwell_deep_gemm_e8m0_used",
|
||||
"is_deep_gemm_supported",
|
||||
]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user