From e2de455c349df8385b18fe447beb6325dcb6af9c Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Thu, 10 Jul 2025 23:18:05 -0400 Subject: [PATCH] [Feature] Integrate SM100 DeepGEMM support (#20087) --- benchmarks/kernels/benchmark_moe.py | 3 + tests/kernels/moe/test_block_fp8.py | 16 +- tests/kernels/moe/test_deepep_deepgemm_moe.py | 5 + tests/kernels/moe/test_deepgemm.py | 55 +------ tests/kernels/quantization/test_block_fp8.py | 27 ++-- .../layers/fused_moe/batched_deep_gemm_moe.py | 21 ++- .../layers/fused_moe/deep_gemm_moe.py | 22 +-- .../layers/fused_moe/fused_moe.py | 12 +- .../layers/fused_moe/prepare_finalize.py | 1 - .../layers/fused_moe/triton_deep_gemm_moe.py | 7 +- vllm/model_executor/layers/fused_moe/utils.py | 7 +- .../layers/mamba/ops/causal_conv1d.py | 3 +- .../layers/quantization/deepgemm.py | 8 +- .../model_executor/layers/quantization/fp8.py | 46 +++++- .../layers/quantization/utils/fp8_utils.py | 126 ++++++++++++++- vllm/utils/deep_gemm.py | 152 ++++++++++++++++++ 16 files changed, 397 insertions(+), 114 deletions(-) create mode 100644 vllm/utils/deep_gemm.py diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 07af58d81c683..51c9f68e43af3 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -86,6 +86,9 @@ def benchmark_config( (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 ) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) + if use_deep_gemm: + # we use the default block shape for deepgemm + block_quant_shape = [128, 128] if use_fp8_w8a8: if block_quant_shape: block_n, block_k = block_quant_shape[0], block_quant_shape[1] diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index c187542205a57..7dc6282326b66 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -15,13 +15,13 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from vllm.platforms import current_platform +from vllm.utils import has_deep_gemm +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used -dg_available = False -try: - import deep_gemm - dg_available = True -except ImportError: - pass +dg_available = has_deep_gemm() + +if dg_available: + from deep_gemm import get_m_alignment_for_contiguous_layout if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", @@ -224,6 +224,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") +@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): @@ -238,8 +239,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, torch.manual_seed(seed) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_m = get_m_alignment_for_contiguous_layout() block_size = [block_m, block_m] dtype = torch.bfloat16 diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index b74137eeaaa65..074771e49a061 100644 --- 
a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from .parallel_utils import ProcessGroupInfo, parallel_launch from .utils import make_test_weights @@ -368,6 +369,8 @@ NUM_EXPERTS = [32] @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @requires_deep_ep @requires_deep_gemm +@pytest.mark.skipif(is_blackwell_deep_gemm_used(), + reason="Skipping test for Blackwell DeepGEMM") def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, world_dp_size: tuple[int, int]): """ @@ -423,6 +426,8 @@ USE_FP8_DISPATCH = [False] @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @requires_deep_ep @requires_deep_gemm +@pytest.mark.skipif(is_blackwell_deep_gemm_used(), + reason="Skipping test for Blackwell DeepGEMM") def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], num_experts: int, diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index fa62507179a20..6a04edafd96c6 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -13,48 +13,18 @@ import torch # vLLM fused-expert reference (Triton fallback + DeepGEMM option) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) -from vllm.utils import cdiv +from vllm.utils import has_deep_gemm +from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8, + per_token_group_cast_to_fp8) -has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None - -if has_deep_gemm: - import deep_gemm - BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout() - BLOCK_SIZE = [BLOCK_M, BLOCK_M] +BLOCK_SIZE = [128, 128] requires_deep_gemm = pytest.mark.skipif( - not has_deep_gemm, + not has_deep_gemm(), reason="Requires deep_gemm kernels", ) -def calc_diff(x: torch.Tensor, y: torch.Tensor): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim - - -def per_block_cast_to_fp8( - x: torch.Tensor, - block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n), - dtype=x.dtype, - device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() - scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - return x_scaled_sub, scales - - def make_block_quant_fp8_weights( e: int, n: int, @@ -111,7 +81,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size): """ tokens_bf16 = torch.randn( m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1) - _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1]) + _, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1]) # expert weight tensors w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, @@ -155,17 +125,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size): 
block_shape=block_size, allow_deep_gemm=True, ) - - base = out_triton.abs().mean() - atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3 - rtol = 0.05 - # ----- Compare ----- - torch.testing.assert_close( - out_deepgemm.to(torch.float32), - out_triton.to(torch.float32), - rtol=rtol, - atol=float(atol), - ) + diff = calc_diff(out_deepgemm, out_triton) + assert diff < 0.001, f"Diff exceeded 0.1%: {diff}" # Note: W1 has shape (E, 2N, K), so N = 512 diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 42d5526dc21f2..97b5102dd4786 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -8,19 +8,15 @@ import pytest import torch from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, - native_w8a8_block_matmul, - per_block_cast_to_fp8) + native_w8a8_block_matmul) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, w8a8_block_fp8_matmul) + get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, + w8a8_block_fp8_matmul) from vllm.platforms import current_platform - -dg_available = False -try: - import deep_gemm - dg_available = True -except ImportError: - pass +from vllm.utils import has_deep_gemm +from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8, + per_token_group_cast_to_fp8) if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", @@ -106,7 +102,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) -@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") +@pytest.mark.skipif(not has_deep_gemm(), + reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -120,9 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - _, block_k = block_size[0], block_size[1] - - A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k) + A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1]) B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32) As = As_fp8.to(torch.float32) @@ -132,14 +127,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): out_dtype) # Transpose earlier so that the testing will not trigger transposing kernels - As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8) + As_fp8 = get_col_major_tma_aligned_tensor(As_fp8) out = torch.zeros((M, N), device='cuda', dtype=out_dtype) assert As_fp8.shape == (M, (K + 127) // 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" - deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) + fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 751ed6abd999a..70ac6688deb7f 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -11,6 +11,7 @@ from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked logger = init_logger(__name__) @@ -271,7 +272,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens - import deep_gemm as dg assert hidden_states.ndim == 3 assert self.block_shape is not None @@ -289,18 +289,15 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # for the M expectation of each batch, correctly setting this value # may lead to better performance. expected_m = max_num_tokens - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale), - (w1, w1_scale), - out=workspace1, - masked_m=expert_num_tokens, - expected_m=expected_m) + fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), + out=workspace1, + masked_m=expert_num_tokens, + expected_m=expected_m) a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, expert_num_tokens) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), - (w2, w2_scale), - out=output, - masked_m=expert_num_tokens, - expected_m=expected_m) + fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), + out=output, + masked_m=expert_num_tokens, + expected_m=expected_m) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index fdeac43902f96..4c0e6665bdc69 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -14,9 +14,10 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.utils import has_deep_gemm, round_up +from vllm.utils.deep_gemm import (m_grouped_fp8_gemm_nt_contiguous, + per_token_group_cast_to_fp8) logger = init_logger(__name__) @@ -127,7 +128,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ): - import deep_gemm as dg assert self.block_shape is not None a1q = hidden_states @@ -164,19 +164,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): (M_sum, N // 2)) mm2_out = _resize_cache(workspace2, (M_sum, K)) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) + m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale), + mm1_out, expert_ids) self.activation(activation, act_out, mm1_out.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None - a2q, a2q_scale = per_token_group_quant_fp8(act_out, - self.block_shape[1], - column_major_scales=True, - out_q=quant_out) + a2q, a2q_scale = per_token_group_cast_to_fp8(act_out, + self.block_shape[1], + column_major_scales=True, + out_q=quant_out) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) + m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale), + mm2_out, expert_ids) torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K))) diff --git 
a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e16cc9e8507d5..6a9767fc6f3fd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -34,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -1171,9 +1172,15 @@ def fused_experts( allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor: # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. + # However, on B200, we use DeepGemm for all cases because it only supports + # the E8M0 scale, which means we requantize the weight and input to that + # scale. Falling back to cutlass or triton in some cases would cause + # accuracy issues. N = w1.size(1) - if (allow_deep_gemm and use_fp8_w8a8 and N > 512 - and _valid_deep_gemm(hidden_states, w1, w2)): + should_use_deep_gemm = ((N > 512 + and _valid_deep_gemm(hidden_states, w1, w2)) + or is_blackwell_deep_gemm_used()) + if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm): assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( hidden_states=hidden_states, @@ -1363,7 +1370,6 @@ def fused_experts_impl( curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 567a0a88fec0a..b15c00c44b5d6 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -48,7 +48,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): assert topk == 1, \ "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 891ffd1c79b45..934a98327288d 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -102,7 +103,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set.
- if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): + if self.allow_deep_gemm and (_valid_deep_gemm_shape(M, N, K) + or is_blackwell_deep_gemm_used()): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts) @@ -132,7 +134,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ): use_deep_gemm = (self.allow_deep_gemm - and _valid_deep_gemm(hidden_states, w1, w2)) + and (_valid_deep_gemm(hidden_states, w1, w2) + or is_blackwell_deep_gemm_used())) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert assert experts is not None diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index b27e99150541b..75228d3faf3d9 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -15,6 +15,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv +from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used, + per_token_group_cast_to_fp8) @triton.jit @@ -115,7 +117,10 @@ def _fp8_quantize( assert not per_act_token assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_fp8(A, block_k) + if is_blackwell_deep_gemm_used(): + A, A_scale = per_token_group_cast_to_fp8(A, block_k) + else: + A, A_scale = per_token_group_quant_fp8(A, block_k) assert cdiv(A.size(-1), block_k) == A_scale.size(-1) return A, A_scale diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index c1641080ea1e5..6793f6def2b7f 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -8,10 +8,9 @@ from typing import Optional, Union import numpy as np import torch -import triton -import triton.language as tl from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.triton_utils import tl, triton @triton.jit() diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index 5903976eaf6b9..d26a932eddb2c 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -6,10 +6,8 @@ import torch from vllm.platforms import current_platform from vllm.triton_utils import triton -from vllm.utils import direct_register_custom_op, has_deep_gemm - -if has_deep_gemm(): - import deep_gemm +from vllm.utils import direct_register_custom_op +from vllm.utils.deep_gemm import fp8_gemm_nt logger = logging.getLogger(__name__) @@ -57,7 +55,7 @@ def w8a8_block_fp8_matmul_deepgemm( output_dtype) # Deepgemm only supports output tensor type as bfloat16 assert C.dtype == torch.bfloat16 - deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) + fp8_gemm_nt((A, As), (B, Bs), C) return C diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5a1a427d7d72e..1e98e6c713840 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, 
QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin) @@ -40,6 +42,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -393,6 +396,19 @@ class Fp8LinearMethod(LinearMethodBase): # Activations not quantized for marlin. del layer.input_scale + # On B200, DeepGemm only supports the E8M0 scale, which means we need to + # requantize the weight and input to that scale + # at the same time. + if is_blackwell_deep_gemm_used(): + assert layer.weight_block_size is not None + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace( + layer.weight.data, + layer.weight_scale_inv.data if hasattr( + layer, "weight_scale_inv") else layer.weight_scale.data, + block_sz, + ) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, @@ -670,15 +686,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. - if self.allow_deep_gemm: + if self.allow_deep_gemm and not is_blackwell_deep_gemm_used(): # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg if _is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ - dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() + get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() if _is_col_major(layer.w2_weight_scale_inv): layer.w2_weight_scale_inv = \ - dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() + get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() # If checkpoint is fp16, quantize in place. elif not self.quant_config.is_checkpoint_fp8_serialized: @@ -797,6 +812,29 @@ class Fp8MoEMethod(FusedMoEMethodBase): del layer.w13_input_scale del layer.w2_input_scale + if is_blackwell_deep_gemm_used(): + assert layer.weight_block_size is not None + # Re-quantise the expert weights so their scales are UE8M0. + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace( + layer.w13_weight.data, + layer.w13_weight_scale_inv.data, + block_sz, + ) + requant_weight_ue8m0_inplace( + layer.w2_weight.data, + layer.w2_weight_scale_inv.data, + block_sz, + ) + + # Ensure column-major TMA alignment expected by DeepGEMM.
+ if _is_col_major(layer.w13_weight_scale_inv): + layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w13_weight_scale_inv).contiguous() + if _is_col_major(layer.w2_weight_scale_inv): + layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w2_weight_scale_inv).contiguous() + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index cbf8231defc6c..1780cc5de2d5c 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -5,6 +5,7 @@ import functools import json import os +from collections.abc import Sequence from typing import Any, Callable, Optional, Union import torch @@ -13,7 +14,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - scaled_dequantize) + group_broadcast) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.platforms import current_platform @@ -235,7 +236,7 @@ def block_quant_to_tensor_quant( The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. Note only float8 is supported for now. """ - x_dq_block = scaled_dequantize(x_q_block, x_s) + x_dq_block = group_broadcast(x_q_block, x_s) x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) return x_q_tensor, scale @@ -651,3 +652,124 @@ def w8a8_block_fp8_matmul( ) return C + + +# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947 +# TODO(wentao): remove this function when DeepGEMM exposes this function +def get_tma_aligned_size(x: int, element_size: int) -> int: + """ + Global memory address of TMA must be 16-byte aligned. + Since we use column-major layout for the LHS scaling tensor, + the M-axis of the LHS scaling tensor needs to be padded to a multiple of + 16 bytes. + + Arguments: + x: original M-axis shape of the LHS scaling tensor. + element_size: element size of the LHS scaling tensor. + + Returns: + M-axis shape of the LHS scaling tensor after padding. + """ + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return cdiv(x, alignment) * alignment + + +# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947 +# TODO(wentao): remove this function when DeepGEMM exposes this function +def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + """ + Returns TMA-aligned transposed format of the input tensor. `torch.transpose` + will be called if necessary. + If the input tensor is already column-major layout and 16-byte aligned along + the M axis (thus meets the requirement of LHS scaling tensor in + DeepGEMM), this function will do nothing. + + Arguments: + x: usually the LHS scaling tensor in GEMM. + + Returns: + The LHS scaling tensor of TMA-aligned transposed format. 
+ """ + # NOTES: for the extreme performance, you may rewrite/fuse this function in + # CUDA + assert x.dim() in (2, 3) + remove_dim = False + m, n = x.shape[-2], x.shape[-1] + aligned_m = get_tma_aligned_size(m, x.element_size()) + if x.dim() == 2: + if x.stride(0) == 1 and x.stride(1) == aligned_m: + return x + x, remove_dim = x.unsqueeze(0), True + + b = x.shape[0] + + # The last kernel gives a column-major TMA aligned layout + if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride( + 2) == aligned_m: + return x.squeeze(0) if remove_dim else x + + # Normal layout requires transposing + aligned_x = torch.transpose( + torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x[:, :m, :] = x + aligned_x = aligned_x[:, :m, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x + + +def requant_weight_ue8m0_inplace( + weight: torch.Tensor, + weight_scale: torch.Tensor, + block_size: Sequence[int] = (128, 128), +) -> None: + """Re-quantise *weight* so that its per-block scaling factors are in the + UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace. + + Args: + weight: Block-quantised weight tensor stored in ``torch.float8_e4m3fn``. + Expected shape ``(..., M, K)``. + weight_scale: Corresponding per-block scale tensor (``torch.float32``) + with shape ``(..., M // block_size[0], K // block_size[1])``. + block_size: 2-element iterable ``[block_m, block_k]`` describing the + block quantisation granularity. + """ + if weight.numel() == 0: + return + + if weight.dtype != torch.float8_e4m3fn: + raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got " + f"{weight.dtype} instead.") + + from vllm.utils.deep_gemm import per_block_cast_to_fp8 + + block_m, block_k = int(block_size[0]), int(block_size[1]) + + # Flatten leading dimensions so we can iterate over the last two dims. + leading_shape = weight.shape[:-2] + if len(leading_shape) == 0: + w_view = weight.unsqueeze(0) + s_view = weight_scale.unsqueeze(0) + else: + w_view = weight.reshape(-1, weight.shape[-2], weight.shape[-1]) + s_view = weight_scale.reshape(-1, *weight_scale.shape[-2:]) + + num_mats = w_view.size(0) + for idx in range(num_mats): + w_q = w_view[idx] + s_old = s_view[idx] + + # De-quantise with the *old* scaling factors (float32). + m_cur, k_cur = w_q.shape + s_float = s_old.to(torch.float32) + # Expand scales along rows and cols by block size, then crop. + s_exp_r = torch.repeat_interleave(s_float, block_m, dim=0) + s_exp = torch.repeat_interleave(s_exp_r, block_k, dim=1) + s_exp = s_exp[:m_cur, :k_cur] + w_dq = w_q.to(torch.float32) * s_exp + # Re-quantise using power-of-two scaling (UE8M0). + w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k]) + + # Write back the results in-place. + w_q.copy_(w_requant) + s_old.copy_(s_requant) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py new file mode 100644 index 0000000000000..1684d6754f504 --- /dev/null +++ b/vllm/utils/deep_gemm.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compatibility wrapper for DeepGEMM API changes. + +Users of vLLM should always import **only** these wrappers. 
+""" +from __future__ import annotations + +import functools +import importlib +from typing import Any, Callable, NoReturn + +import torch + +import vllm.envs as envs +from vllm.utils import cuda_get_device_properties, has_deep_gemm + + +@functools.cache +def is_blackwell_deep_gemm_used() -> bool: + """Return ``True`` if vLLM is configured to use DeepGEMM on a + Blackwell-class GPU. + """ + + if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() + and _per_block_cast_impl is not None): + return False + + return cuda_get_device_properties(0, ("major", ))[0] == 10 + + +def _missing(*_: Any, **__: Any) -> NoReturn: + """Placeholder for unavailable DeepGEMM backend.""" + raise RuntimeError( + "DeepGEMM backend is not available. Please install the `deep_gemm` " + "package to enable FP8 kernels.") + + +def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None: + """Return the *new* symbol if it exists, otherwise the *old* one.""" + if hasattr(module, new): + return getattr(module, new) + if hasattr(module, old): + return getattr(module, old) + return None + + +if not has_deep_gemm(): + _fp8_gemm_nt_impl: Callable[..., Any] | None = None + _grouped_impl: Callable[..., Any] | None = None + _grouped_masked_impl: Callable[..., Any] | None = None + _per_token_cast_impl: Callable[..., Any] | None = None + _per_block_cast_impl: Callable[..., Any] | None = None +else: + _dg = importlib.import_module("deep_gemm") # type: ignore + + _fp8_gemm_nt_impl = _resolve_symbol( + _dg, + "fp8_gemm_nt", + "gemm_fp8_fp8_bf16_nt", + ) + _grouped_impl = _resolve_symbol( + _dg, + "m_grouped_fp8_gemm_nt_contiguous", + "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous", + ) + _grouped_masked_impl = _resolve_symbol( + _dg, + "fp8_m_grouped_gemm_nt_masked", + "m_grouped_gemm_fp8_fp8_bf16_nt_masked", + ) + + # Try to get per_token_cast_to_fp8 from DeepGEMM math utils. + try: + _math_mod = importlib.import_module( + "deep_gemm.utils.math") # type: ignore + _per_token_cast_impl = getattr(_math_mod, "per_token_cast_to_fp8", + None) + _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8", + None) + except ModuleNotFoundError: + _per_token_cast_impl = None + _per_block_cast_impl = None + + +def fp8_gemm_nt(*args, **kwargs): + if _fp8_gemm_nt_impl is None: + return _missing(*args, **kwargs) + return _fp8_gemm_nt_impl(*args, **kwargs) + + +def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): + if _grouped_impl is None: + return _missing(*args, **kwargs) + return _grouped_impl(*args, **kwargs) + + +def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): + if _grouped_masked_impl is None: + return _missing(*args, **kwargs) + return _grouped_masked_impl(*args, **kwargs) + + +def per_token_group_cast_to_fp8(x, group_size, *args, **kwargs): + """Wrapper for token-wise FP8 quantisation. + + • If DeepGEMM provides ``per_token_cast_to_fp8`` (new API), use it. 
+ • Otherwise, fall back to vLLM's ``per_token_group_quant_fp8`` + """ + + if _per_token_cast_impl is not None and is_blackwell_deep_gemm_used(): + assert group_size == 128, "group_size must be 128 for deepgemm" + return _per_token_cast_impl(x) + + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8 as _ptg) + return _ptg(x, group_size, *args, **kwargs) + + +def per_block_cast_to_fp8(x, *args, **kwargs): + if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used(): + return _per_block_cast_impl(x) + # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils + from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf + return _pbcf(x, *args, **kwargs) + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + """Return a global difference metric for unit tests. + + DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element + error, causing ``torch.testing.assert_close`` to fail. Instead of checking + every element, we compute a cosine-style similarity over the whole tensor + and report ``1 - sim``. Once kernel accuracy improves this helper can be + removed. + """ + + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +__all__ = [ + "calc_diff", + "fp8_gemm_nt", + "m_grouped_fp8_gemm_nt_contiguous", + "fp8_m_grouped_gemm_nt_masked", + "per_token_group_cast_to_fp8", + "per_block_cast_to_fp8", + "is_blackwell_deep_gemm_used", +]