[Performance][B200] Fix deepgemm prologue (#27897)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath 2025-11-12 16:13:03 -05:00 committed by GitHub
parent 478ee511de
commit 74a9a9faad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 163 additions and 48 deletions

View File

@ -232,6 +232,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
super().__init__(quant_config)
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
assert self.quant_config.use_fp8_w8a8
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@ -250,6 +251,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
return False
def supports_packed_ue8m0_act_scales(self) -> bool:
"""
DeepGemm supports packed ue8m0 activation scales format in devices == sm100
"""
return current_platform.is_device_capability(100)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()

View File

@ -6,6 +6,7 @@ import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
@ -20,6 +21,8 @@ from vllm.v1.worker.ubatching import (
dbo_maybe_run_recv_hook,
)
logger = init_logger(__name__)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
@ -94,6 +97,29 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.handles: list[tuple | None] = [None, None]
self.num_dispatchers_ = num_dispatchers
# We don't have enough information to determine if we should dispatch
# activation scales in a packed ue8m0 format during object construction
# time. This setting is handled by post_init_setup.
self.use_ue8m0_dispatch = False
def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute):
if not fused_experts.supports_packed_ue8m0_act_scales():
# Early exit.
return
if self.use_fp8_dispatch:
logger.debug_once(
"Update DeepEPLLPrepareFinalize to do packed ue8m0 scales dispatch."
)
self.use_ue8m0_dispatch = True
else:
logger.warning_once(
"DeepEPLLPrepareAndFinalize is setup to dispatch raw/unquantized "
f"activations despite ({fused_experts.__class__.__name__}) being able "
"to support quantized activations.",
scope="local",
)
def num_dispatchers(self) -> int:
return self.num_dispatchers_
@ -206,6 +232,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
# round_scale needs to be set to dispatch in ue8m0
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
async_finish=False,
return_recv_hook=True,
)

View File

@ -149,6 +149,15 @@ class FusedMoEPrepareAndFinalize(ABC):
described above.
"""
def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
"""
Initialize FusedMoEPrepareAndFinalize settings that depend on
FusedMoEPermuteExpertsUnpermute experts object.
The FusedMoEPrepareAndFinalize implementations that have such
dependencies may choose to override this function.
"""
return
@abstractmethod
def prepare(
self,
@ -503,6 +512,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
def supports_packed_ue8m0_act_scales(self) -> bool:
"""
A flag indicating whether or not this class can process packed ue8m0
activation scales.
"""
return False
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
"""
Workspace type: The dtype to use for the workspace tensors.
@ -698,6 +714,8 @@ class FusedMoEModularKernel(torch.nn.Module):
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
self._post_init_setup()
assert (
prepare_finalize.activation_format == fused_experts.activation_formats[0]
), (
@ -707,6 +725,13 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.activation_formats[0]}"
)
def _post_init_setup(self):
"""
Resolve any leftover setup dependencies between self.prepare_finalize
and self.fused_experts here.
"""
self.prepare_finalize.post_init_setup(self.fused_experts)
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps.

View File

@ -60,11 +60,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
expert_weight_is_col_major,
deepgemm_post_process_fp8_weight_block,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy,
requant_weight_ue8m0_inplace,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@ -94,7 +93,6 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
get_col_major_tma_aligned_tensor,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
@ -846,15 +844,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale_inv
if self.allow_deep_gemm:
dg_w13_weight, dg_w13_weight_scale_inv = (
deepgemm_post_process_fp8_weight_block(
wq=layer.w13_weight.data,
ws=layer.w13_weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv
)
dg_w2_weight, dg_w2_weight_scale_inv = (
deepgemm_post_process_fp8_weight_block(
wq=layer.w2_weight.data,
ws=layer.w2_weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
)
layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = Parameter(
dg_w13_weight_scale_inv, requires_grad=False
)
layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(
dg_w2_weight_scale_inv, requires_grad=False
)
# If checkpoint is fp16, quantize in place.
elif not self.quant_config.is_checkpoint_fp8_serialized:
@ -990,31 +1004,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
if is_deep_gemm_e8m0_used() and self.block_quant:
assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.w13_weight.data,
layer.w13_weight_scale_inv.data,
block_sz,
)
requant_weight_ue8m0_inplace(
layer.w2_weight.data,
layer.w2_weight_scale_inv.data,
block_sz,
)
# Ensure column-major TMA alignment expected by DeepGEMM.
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale_inv
)
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv
)
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
if (
self.rocm_aiter_moe_enabled
@ -1037,7 +1026,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
BatchedTritonOrDeepGemmExperts,
BatchedDeepGemmExperts,
BatchedTritonExperts,
TritonOrDeepGemmExperts,
)
@ -1053,20 +1043,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
experts_impl = (
BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
)
logger.debug(
"BatchedTritonOrDeepGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
"%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
experts_impl.__name__,
self.__class__.__name__,
max_num_tokens_per_rank,
self.weight_block_size,
False,
)
return BatchedTritonOrDeepGemmExperts(
return experts_impl(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
experts = select_cutlass_fp8_gemm_impl(
self.moe,

View File

@ -34,6 +34,7 @@ from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
transform_sf_into_required_layout,
)
from vllm.utils.torch_utils import direct_register_custom_op
@ -929,6 +930,50 @@ def requant_weight_ue8m0_inplace(
s_old.copy_(s_requant)
def deepgemm_post_process_fp8_weight_block(
wq: torch.Tensor, ws: torch.Tensor, quant_block_shape: tuple[int], use_e8m0: bool
) -> tuple[torch.Tensor, torch.Tensor]:
assert wq.dtype == torch.float8_e4m3fn, (
"Expected quantized tensor dtype "
f"to be torch.float8_e4m3fn, got {wq.dtype} instead."
)
assert ws.dtype == torch.float32, (
f"Expected tensor scales dtype to be torch.float32, got {ws.dtype} instead"
)
if use_e8m0:
requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape)
original_ndim = wq.ndim
if wq.ndim == 2:
assert ws.ndim == 2
wq = wq.unsqueeze(0)
ws = ws.unsqueeze(0)
# From https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/utils/layout.hpp#L46
recipe = (1, 128, 128)
# Ref : https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/apis/gemm.hpp
# DeepGemm uses the `transform_sf_into_required_layout` function to
# represent scales in the correct format.
dg_ws = transform_sf_into_required_layout(
sf=ws,
mn=wq.size(1),
k=wq.size(2),
recipe=recipe,
num_groups=wq.size(0),
# is the scale factors for A in (Refers to the argument A in A @ B).
# Weights are B.
is_sfa=False,
)
if original_ndim == 2:
wq = wq.squeeze(0)
dg_ws = dg_ws.squeeze(0)
return wq, dg_ws
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
@ -1141,11 +1186,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
layer.orig_dtype, layer.weight
)
if is_deep_gemm_e8m0_used() and should_use_deepgemm:
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.weight.data, layer.weight_scale.data, block_sz
if should_use_deepgemm:
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=layer.weight.data,
ws=layer.weight_scale.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False)
def expert_weight_is_col_major(x: torch.Tensor) -> bool:

View File

@ -49,10 +49,6 @@ def is_deep_gemm_e8m0_used() -> bool:
logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
return False
if envs.VLLM_USE_FLASHINFER_MOE_FP8:
logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
return False
if envs.VLLM_USE_DEEP_GEMM_E8M0:
logger.info_once("DeepGEMM E8M0 enabled on current platform.")
return True
@ -77,6 +73,7 @@ _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
def _lazy_init() -> None:
@ -86,6 +83,7 @@ def _lazy_init() -> None:
global _get_paged_mqa_logits_metadata_impl
global _get_mn_major_tma_aligned_tensor_impl
global _get_mk_alignment_for_contiguous_layout_impl
global _transform_sf_into_required_layout_impl
# fast path
if (
_fp8_gemm_nt_impl is not None
@ -95,6 +93,7 @@ def _lazy_init() -> None:
or _fp8_paged_mqa_logits_impl is not None
or _get_paged_mqa_logits_metadata_impl is not None
or _get_mk_alignment_for_contiguous_layout_impl is not None
or _transform_sf_into_required_layout_impl is not None
):
return
@ -124,6 +123,9 @@ def _lazy_init() -> None:
_get_mk_alignment_for_contiguous_layout_impl = getattr(
_dg, "get_mk_alignment_for_contiguous_layout", None
)
_transform_sf_into_required_layout_impl = getattr(
_dg, "transform_sf_into_required_layout", None
)
def get_num_sms() -> int:
@ -179,6 +181,15 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
)
def transform_sf_into_required_layout(*args, **kwargs):
_lazy_init()
if _transform_sf_into_required_layout_impl is None:
return _missing(*args, **kwargs)
return _transform_sf_into_required_layout_impl(
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
)
def fp8_mqa_logits(
q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor],