[Feature] Integrate SM100 DeepGEMM support (#20087)

This commit is contained in:
Wentao Ye 2025-07-10 23:18:05 -04:00 committed by GitHub
parent 5b032352cc
commit e2de455c34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 397 additions and 114 deletions

View File

@ -86,6 +86,9 @@ def benchmark_config(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32 (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
) )
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_deep_gemm:
# we use the default block shape for deepgemm
block_quant_shape = [128, 128]
if use_fp8_w8a8: if use_fp8_w8a8:
if block_quant_shape: if block_quant_shape:
block_n, block_k = block_quant_shape[0], block_quant_shape[1] block_n, block_k = block_quant_shape[0], block_quant_shape[1]

View File

@ -15,13 +15,13 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe) fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
dg_available = False dg_available = has_deep_gemm()
try:
import deep_gemm if dg_available:
dg_available = True from deep_gemm import get_m_alignment_for_contiguous_layout
except ImportError:
pass
if current_platform.get_device_capability() < (9, 0): if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@ -224,6 +224,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch): monkeypatch):
@ -238,8 +239,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
torch.manual_seed(seed) torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
block_m = get_m_alignment_for_contiguous_layout()
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m] block_size = [block_m, block_m]
dtype = torch.bfloat16 dtype = torch.bfloat16

View File

@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm from vllm.utils import has_deep_ep, has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
from .parallel_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights from .utils import make_test_weights
@ -368,6 +369,8 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep @requires_deep_ep
@requires_deep_gemm @requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]): topk: int, world_dp_size: tuple[int, int]):
""" """
@ -423,6 +426,8 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep @requires_deep_ep
@requires_deep_gemm @requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ll_deepep_deepgemm_moe( def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int], mnk: tuple[int, int, int],
num_experts: int, num_experts: int,

View File

@ -13,48 +13,18 @@ import torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option) # vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.utils import has_deep_gemm
per_token_group_quant_fp8) from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
from vllm.utils import cdiv per_token_group_cast_to_fp8)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None BLOCK_SIZE = [128, 128]
if has_deep_gemm:
import deep_gemm
BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
BLOCK_SIZE = [BLOCK_M, BLOCK_M]
requires_deep_gemm = pytest.mark.skipif( requires_deep_gemm = pytest.mark.skipif(
not has_deep_gemm, not has_deep_gemm(),
reason="Requires deep_gemm kernels", reason="Requires deep_gemm kernels",
) )
def calc_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_block_quant_fp8_weights( def make_block_quant_fp8_weights(
e: int, e: int,
n: int, n: int,
@ -111,7 +81,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
""" """
tokens_bf16 = torch.randn( tokens_bf16 = torch.randn(
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1) m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1]) _, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
# expert weight tensors # expert weight tensors
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
@ -155,17 +125,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
block_shape=block_size, block_shape=block_size,
allow_deep_gemm=True, allow_deep_gemm=True,
) )
diff = calc_diff(out_deepgemm, out_triton)
base = out_triton.abs().mean() assert diff < 0.001, f"Diff exceeded 1%: {diff}"
atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3
rtol = 0.05
# ----- Compare -----
torch.testing.assert_close(
out_deepgemm.to(torch.float32),
out_triton.to(torch.float32),
rtol=rtol,
atol=float(atol),
)
# Note: W1 has shape (E, 2N, K), so N = 512 # Note: W1 has shape (E, 2N, K), so N = 512

View File

@ -8,19 +8,15 @@ import pytest
import torch import torch
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul, native_w8a8_block_matmul)
per_block_cast_to_fp8)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul) get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
w8a8_block_fp8_matmul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
dg_available = False from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
try: per_token_group_cast_to_fp8)
import deep_gemm
dg_available = True
except ImportError:
pass
if current_platform.get_device_capability() < (9, 0): if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@ -106,7 +102,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed", "M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @pytest.mark.skipif(not has_deep_gemm(),
reason="DeepGemm kernels not available.")
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes # only aligned sizes
@ -120,9 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
_, block_k = block_size[0], block_size[1] A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32) B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
As = As_fp8.to(torch.float32) As = As_fp8.to(torch.float32)
@ -132,14 +127,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
out_dtype) out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels # Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8) As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device='cuda', dtype=out_dtype) out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) // assert As_fp8.shape == (M, (K + 127) //
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
rel_diff = (torch.mean( rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /

View File

@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
logger = init_logger(__name__) logger = init_logger(__name__)
@ -271,7 +272,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert expert_tokens_meta is not None assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens expert_num_tokens = expert_tokens_meta.expert_num_tokens
import deep_gemm as dg
assert hidden_states.ndim == 3 assert hidden_states.ndim == 3
assert self.block_shape is not None assert self.block_shape is not None
@ -289,18 +289,15 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# for the M expectation of each batch, correctly setting this value # for the M expectation of each batch, correctly setting this value
# may lead to better performance. # may lead to better performance.
expected_m = max_num_tokens expected_m = max_num_tokens
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale), out=workspace1,
(w1, w1_scale), masked_m=expert_num_tokens,
out=workspace1, expected_m=expected_m)
masked_m=expert_num_tokens,
expected_m=expected_m)
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens) expert_num_tokens)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
(w2, w2_scale), out=output,
out=output, masked_m=expert_num_tokens,
masked_m=expert_num_tokens, expected_m=expected_m)
expected_m=expected_m)

View File

@ -14,9 +14,10 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import _resize_cache
_resize_cache, per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm, round_up from vllm.utils import has_deep_gemm, round_up
from vllm.utils.deep_gemm import (m_grouped_fp8_gemm_nt_contiguous,
per_token_group_cast_to_fp8)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -127,7 +128,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
): ):
import deep_gemm as dg
assert self.block_shape is not None assert self.block_shape is not None
a1q = hidden_states a1q = hidden_states
@ -164,19 +164,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
(M_sum, N // 2)) (M_sum, N // 2))
mm2_out = _resize_cache(workspace2, (M_sum, K)) mm2_out = _resize_cache(workspace2, (M_sum, K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) mm1_out, expert_ids)
self.activation(activation, act_out, mm1_out.view(-1, N)) self.activation(activation, act_out, mm1_out.view(-1, N))
a2q_scale: Optional[torch.Tensor] = None a2q_scale: Optional[torch.Tensor] = None
a2q, a2q_scale = per_token_group_quant_fp8(act_out, a2q, a2q_scale = per_token_group_cast_to_fp8(act_out,
self.block_shape[1], self.block_shape[1],
column_major_scales=True, column_major_scales=True,
out_q=quant_out) out_q=quant_out)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) mm2_out, expert_ids)
torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K))) torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))

View File

@ -34,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@ -1171,9 +1172,15 @@ def fused_experts(
allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor: allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better # For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available. # permute/unpermute ops are available.
# However, on B200, we use DeepGemm for all cases because it only supports
# E8M0 scales, which means we requantize the weights and inputs to that
# scale format. Falling back to cutlass or triton for some cases would
# cause accuracy issues.
N = w1.size(1) N = w1.size(1)
if (allow_deep_gemm and use_fp8_w8a8 and N > 512 should_use_deep_gemm = ((N > 512
and _valid_deep_gemm(hidden_states, w1, w2)): and _valid_deep_gemm(hidden_states, w1, w2))
or is_blackwell_deep_gemm_used())
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
assert apply_router_weight_on_input is False assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8( return deep_gemm_moe_fp8(
hidden_states=hidden_states, hidden_states=hidden_states,
@ -1363,7 +1370,6 @@ def fused_experts_impl(
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states, A=curr_hidden_states,
A_scale=a1_scale, A_scale=a1_scale,

View File

@ -48,7 +48,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
assert topk == 1, \ assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1" "apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype)) a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype, a1, a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape) quant_config.per_act_token_quant, quant_config.block_shape)

View File

@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@ -102,7 +103,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): if self.allow_deep_gemm and (_valid_deep_gemm_shape(M, N, K)
or is_blackwell_deep_gemm_used()):
assert self.deep_gemm_expert is not None assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes( return self.deep_gemm_expert.workspace_shapes(
a, aq, M, N, K, topk, global_num_experts, local_num_experts) a, aq, M, N, K, topk, global_num_experts, local_num_experts)
@ -132,7 +134,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
): ):
use_deep_gemm = (self.allow_deep_gemm use_deep_gemm = (self.allow_deep_gemm
and _valid_deep_gemm(hidden_states, w1, w2)) and (_valid_deep_gemm(hidden_states, w1, w2)
or is_blackwell_deep_gemm_used()))
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
assert experts is not None assert experts is not None

View File

@ -15,6 +15,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
per_token_group_cast_to_fp8)
@triton.jit @triton.jit
@ -115,7 +117,10 @@ def _fp8_quantize(
assert not per_act_token assert not per_act_token
assert len(block_shape) == 2 assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1] _, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k) if is_blackwell_deep_gemm_used():
A, A_scale = per_token_group_cast_to_fp8(A, block_k)
else:
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert cdiv(A.size(-1), block_k) == A_scale.size(-1) assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
return A, A_scale return A, A_scale

View File

@ -8,10 +8,9 @@ from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
import triton
import triton.language as tl
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import tl, triton
@triton.jit() @triton.jit()

View File

@ -6,10 +6,8 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op, has_deep_gemm from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_gemm_nt
if has_deep_gemm():
import deep_gemm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -57,7 +55,7 @@ def w8a8_block_fp8_matmul_deepgemm(
output_dtype) output_dtype)
# Deepgemm only supports output tensor type as bfloat16 # Deepgemm only supports output tensor type as bfloat16
assert C.dtype == torch.bfloat16 assert C.dtype == torch.bfloat16
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) fp8_gemm_nt((A, As), (B, Bs), C)
return C return C

View File

@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin) prepare_moe_fp8_layer_for_marlin)
@ -40,6 +42,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
@ -393,6 +396,19 @@ class Fp8LinearMethod(LinearMethodBase):
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.input_scale del layer.input_scale
# On B200, DeepGemm only supports E8M0 scales, which means we need to
# requantize both the weight and the input to that scale format
# at the same time.
if is_blackwell_deep_gemm_used():
assert layer.weight_block_size is not None
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.weight.data,
layer.weight_scale_inv.data if hasattr(
layer, "weight_scale_inv") else layer.weight_scale.data,
block_sz,
)
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
@ -670,15 +686,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do # DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons. # it ahead of time for performance reasons.
if self.allow_deep_gemm: if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
# Lazy import to avoid CUDA initialization problems. # Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
if _is_col_major(layer.w13_weight_scale_inv): if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \ layer.w13_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
if _is_col_major(layer.w2_weight_scale_inv): if _is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = \ layer.w2_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
# If checkpoint is fp16, quantize in place. # If checkpoint is fp16, quantize in place.
elif not self.quant_config.is_checkpoint_fp8_serialized: elif not self.quant_config.is_checkpoint_fp8_serialized:
@ -797,6 +812,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
if is_blackwell_deep_gemm_used():
assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.w13_weight.data,
layer.w13_weight_scale_inv.data,
block_sz,
)
requant_weight_ue8m0_inplace(
layer.w2_weight.data,
layer.w2_weight_scale_inv.data,
block_sz,
)
# Ensure column-major TMA alignment expected by DeepGEMM.
if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale_inv).contiguous()
if _is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv).contiguous()
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,

View File

@ -5,6 +5,7 @@
import functools import functools
import json import json
import os import os
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import torch import torch
@ -13,7 +14,7 @@ import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_dequantize) group_broadcast)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED) CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -235,7 +236,7 @@ def block_quant_to_tensor_quant(
The outputs are tensor-wise quantization tensor and tensor-wise The outputs are tensor-wise quantization tensor and tensor-wise
quantization scale. Note only float8 is supported for now. quantization scale. Note only float8 is supported for now.
""" """
x_dq_block = scaled_dequantize(x_q_block, x_s) x_dq_block = group_broadcast(x_q_block, x_s)
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
return x_q_tensor, scale return x_q_tensor, scale
@ -651,3 +652,124 @@ def w8a8_block_fp8_matmul(
) )
return C return C
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """Pad an M-axis length so that it covers whole 16-byte units.

    TMA requires global memory addresses to be 16-byte aligned. The LHS
    scaling tensor is stored column-major, so its M axis has to be rounded
    up to a multiple of 16 bytes.

    Arguments:
        x: unpadded M-axis length of the LHS scaling tensor, in elements.
        element_size: size in bytes of one element of that tensor; must
            divide 16 evenly.

    Returns:
        The M-axis length after padding, in elements.
    """
    bytes_per_unit = 16
    assert bytes_per_unit % element_size == 0
    # Number of elements that fit in one 16-byte unit.
    elems_per_unit = bytes_per_unit // element_size
    num_units = cdiv(x, elems_per_unit)
    return num_units * elems_per_unit
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` in a column-major layout whose M axis is TMA-aligned.

    A 2-D or 3-D tensor is accepted. If its strides already describe a
    column-major layout with the padded M-axis length (as computed by
    ``get_tma_aligned_size``), the input is returned without copying.
    Otherwise the data is copied into a freshly allocated buffer with that
    layout.

    Arguments:
        x: usually the LHS scaling tensor of a GEMM, shape ``(m, n)`` or
            ``(b, m, n)``.

    Returns:
        A tensor with the same shape as ``x`` whose last two dimensions
        have strides ``(1, aligned_m)``.
    """
    # NOTE: for extreme performance this could be rewritten/fused in CUDA.
    assert x.dim() in (2, 3)
    rows, cols = x.shape[-2], x.shape[-1]
    padded_rows = get_tma_aligned_size(rows, x.element_size())

    # 2-D inputs are temporarily promoted to a batch of one and squeezed
    # back before returning.
    was_2d = x.dim() == 2
    if was_2d:
        if x.stride(0) == 1 and x.stride(1) == padded_rows:
            # Already column-major and aligned: nothing to do.
            return x
        x = x.unsqueeze(0)

    batch = x.shape[0]

    # A previous kernel may already have produced the aligned layout.
    already_aligned = (x.stride(0) == padded_rows * cols
                       and x.stride(1) == 1 and x.stride(2) == padded_rows)
    if already_aligned:
        return x.squeeze(0) if was_2d else x

    # Row-major input: allocate (b, n, aligned_m) storage and view it
    # transposed so that the M axis is the unit-stride one, then copy.
    storage = torch.empty((batch, cols, padded_rows),
                          device=x.device,
                          dtype=x.dtype)
    aligned = torch.transpose(storage, 1, 2)
    aligned[:, :rows, :] = x
    # Drop the padding rows from the logical shape; strides are unchanged.
    aligned = aligned[:, :rows, :]
    return aligned.squeeze(0) if was_2d else aligned
def requant_weight_ue8m0_inplace(
weight: torch.Tensor,
weight_scale: torch.Tensor,
block_size: Sequence[int] = (128, 128),
) -> None:
"""Re-quantise *weight* so that its per-block scaling factors are in the
UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace.
Args:
weight: Block-quantised weight tensor stored in ``torch.float8_e4m3fn``.
Expected shape ``(..., M, K)``.
weight_scale: Corresponding per-block scale tensor (``torch.float32``)
with shape ``(..., M // block_size[0], K // block_size[1])``.
block_size: 2-element iterable ``[block_m, block_k]`` describing the
block quantisation granularity.
"""
if weight.numel() == 0:
return
if weight.dtype != torch.float8_e4m3fn:
raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got "
f"{weight.dtype} instead.")
from vllm.utils.deep_gemm import per_block_cast_to_fp8
block_m, block_k = int(block_size[0]), int(block_size[1])
# Flatten leading dimensions so we can iterate over the last two dims.
leading_shape = weight.shape[:-2]
if len(leading_shape) == 0:
w_view = weight.unsqueeze(0)
s_view = weight_scale.unsqueeze(0)
else:
w_view = weight.reshape(-1, weight.shape[-2], weight.shape[-1])
s_view = weight_scale.reshape(-1, *weight_scale.shape[-2:])
num_mats = w_view.size(0)
for idx in range(num_mats):
w_q = w_view[idx]
s_old = s_view[idx]
# De-quantise with the *old* scaling factors (float32).
m_cur, k_cur = w_q.shape
s_float = s_old.to(torch.float32)
# Expand scales along rows and cols by block size, then crop.
s_exp_r = torch.repeat_interleave(s_float, block_m, dim=0)
s_exp = torch.repeat_interleave(s_exp_r, block_k, dim=1)
s_exp = s_exp[:m_cur, :k_cur]
w_dq = w_q.to(torch.float32) * s_exp
# Re-quantise using power-of-two scaling (UE8M0).
w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k])
# Write back the results in-place.
w_q.copy_(w_requant)
s_old.copy_(s_requant)

152
vllm/utils/deep_gemm.py Normal file
View File

@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.
Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations
import functools
import importlib
from typing import Any, Callable, NoReturn
import torch
import vllm.envs as envs
from vllm.utils import cuda_get_device_properties, has_deep_gemm
@functools.cache
def is_blackwell_deep_gemm_used() -> bool:
    """Tell whether DeepGEMM is enabled and running on a Blackwell-class GPU.

    Returns ``True`` only when ``VLLM_USE_DEEP_GEMM`` is set, ``deep_gemm`` is
    importable, DeepGEMM's ``per_block_cast_to_fp8`` helper was resolved, and
    CUDA device 0 reports major version 10. The answer is cached after the
    first call.
    """
    deep_gemm_ready = (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()
                       and _per_block_cast_impl is not None)
    if not deep_gemm_ready:
        return False

    # NOTE: only device 0 is inspected.
    major_version = cuda_get_device_properties(0, ("major", ))[0]
    return major_version == 10
def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in for any DeepGEMM entry point that could not be resolved.

    Accepts and ignores arbitrary arguments.

    Raises:
        RuntimeError: Always.
    """
    message = (
        "DeepGEMM backend is not available. Please install the `deep_gemm` "
        "package to enable FP8 kernels.")
    raise RuntimeError(message)
def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
"""Return the *new* symbol if it exists, otherwise the *old* one."""
if hasattr(module, new):
return getattr(module, new)
if hasattr(module, old):
return getattr(module, old)
return None
# Resolved DeepGEMM entry points. Each one stays ``None`` when the package (or
# the specific symbol) is unavailable; the public wrappers check for that.
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_token_cast_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None

if has_deep_gemm():
    _dg = importlib.import_module("deep_gemm")  # type: ignore

    # For every kernel, prefer the new symbol name and fall back to the
    # legacy one.
    _fp8_gemm_nt_impl = _resolve_symbol(
        _dg,
        "fp8_gemm_nt",
        "gemm_fp8_fp8_bf16_nt",
    )
    _grouped_impl = _resolve_symbol(
        _dg,
        "m_grouped_fp8_gemm_nt_contiguous",
        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
    )
    _grouped_masked_impl = _resolve_symbol(
        _dg,
        "fp8_m_grouped_gemm_nt_masked",
        "m_grouped_gemm_fp8_fp8_bf16_nt_masked",
    )

    # The FP8 cast helpers live in DeepGEMM's math utils, which may not exist.
    try:
        _math_mod = importlib.import_module(
            "deep_gemm.utils.math")  # type: ignore
    except ModuleNotFoundError:
        # Leave both cast helpers as ``None`` so the wrappers take their
        # fallback path.
        pass
    else:
        _per_token_cast_impl = getattr(_math_mod, "per_token_cast_to_fp8",
                                       None)
        _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
                                       None)
def fp8_gemm_nt(*args, **kwargs):
    """Forward the call to the DeepGEMM ``fp8_gemm_nt`` implementation.

    All arguments are passed through unchanged to whichever symbol was
    resolved at import time. Raises ``RuntimeError`` (via ``_missing``) when
    no implementation is available.
    """
    impl = _fp8_gemm_nt_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward the call to DeepGEMM's contiguous M-grouped FP8 GEMM.

    All arguments are passed through unchanged to whichever symbol was
    resolved at import time. Raises ``RuntimeError`` (via ``_missing``) when
    no implementation is available.
    """
    impl = _grouped_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward the call to DeepGEMM's masked M-grouped FP8 GEMM.

    All arguments are passed through unchanged to whichever symbol was
    resolved at import time. Raises ``RuntimeError`` (via ``_missing``) when
    no implementation is available.
    """
    impl = _grouped_masked_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)
def per_token_group_cast_to_fp8(x, group_size, *args, **kwargs):
    """Wrapper for token-wise FP8 quantisation.

    DeepGEMM's ``per_token_cast_to_fp8`` is used only when it was resolved
    *and* ``is_blackwell_deep_gemm_used()`` is true; on that path
    ``group_size`` must be 128 and any extra arguments are ignored.
    In every other case the call is forwarded, with all arguments, to vLLM's
    ``per_token_group_quant_fp8``.
    """
    use_deep_gemm_cast = (_per_token_cast_impl is not None
                          and is_blackwell_deep_gemm_used())
    if not use_deep_gemm_cast:
        # Imported lazily, only on the fallback path.
        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
            per_token_group_quant_fp8 as _ptg)
        return _ptg(x, group_size, *args, **kwargs)

    assert group_size == 128, "group_size must be 128 for deepgemm"
    return _per_token_cast_impl(x)
def per_block_cast_to_fp8(x, *args, **kwargs):
    """Wrapper for block-wise FP8 quantisation.

    DeepGEMM's ``per_block_cast_to_fp8`` is called with ``x`` alone (extra
    arguments are ignored) when it was resolved and
    ``is_blackwell_deep_gemm_used()`` is true. Otherwise the call is
    forwarded, with all arguments, to the reference implementation that
    currently lives in ``tests.kernels.quant_utils``.
    """
    use_deep_gemm_cast = (_per_block_cast_impl is not None
                          and is_blackwell_deep_gemm_used())
    if not use_deep_gemm_cast:
        # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
        from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
        return _pbcf(x, *args, **kwargs)

    return _per_block_cast_impl(x)
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing ``torch.testing.assert_close`` to fail. Instead of checking
    every element, we compute a cosine-style similarity over the whole tensor
    and report ``1 - sim``. Once kernel accuracy improves this helper can be
    removed.

    Args:
        x: First tensor.
        y: Second tensor; must be broadcast-compatible with ``x``.

    Returns:
        A 0-dim ``float64`` tensor equal to
        ``1 - 2 * sum(x * y) / sum(x * x + y * y)``. When the denominator is
        exactly zero (both inputs all-zero or empty) the inputs are identical
        and ``0`` is returned instead of NaN.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()

    # A zero denominator means every element of both tensors is zero, i.e. the
    # tensors match exactly. ``torch.where`` keeps this branch-free so no
    # device-to-host synchronisation is introduced.
    is_degenerate = denominator == 0
    one = torch.ones_like(denominator)
    safe_denominator = torch.where(is_degenerate, one, denominator)
    sim = torch.where(is_degenerate, one,
                      2 * (x * y).sum() / safe_denominator)
    return 1 - sim
# Public API of this compatibility module: the DeepGEMM wrappers plus the
# helpers that callers are expected to import from here.
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_token_group_cast_to_fp8",
    "per_block_cast_to_fp8",
    "is_blackwell_deep_gemm_used",
]