[Performance][B200] Fix deepgemm prologue (#27897)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Authored by Varun Sundar Rabindranath on 2025-11-12 16:13:03 -05:00; committed by GitHub
parent 478ee511de
commit 74a9a9faad
6 changed files with 163 additions and 48 deletions


@@ -232,6 +232,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         """
         super().__init__(quant_config)
         assert self.block_shape == get_mk_alignment_for_contiguous_layout()
+        assert self.quant_config.use_fp8_w8a8

         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
@@ -250,6 +251,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def supports_expert_map(self) -> bool:
         return False

+    def supports_packed_ue8m0_act_scales(self) -> bool:
+        """
+        DeepGemm supports the packed ue8m0 activation-scales format on
+        sm100 devices.
+        """
+        return current_platform.is_device_capability(100)

     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Let PrepareAndFinalize::finalize() decide the impl.
         return TopKWeightAndReduceDelegate()
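
The new supports_packed_ue8m0_act_scales() hook lets the prepare/finalize side ask whether the experts kernel can consume activation scales in DeepGEMM's packed ue8m0 format, which BatchedDeepGemmExperts only advertises on sm100 (B200-class) devices. For intuition only: ue8m0 is an exponent-only scale encoding, so scales must land on a power-of-two grid. The sketch below is not part of this commit and shows just one plausible rounding rule for pushing fp32 scales onto that grid.

import torch

def round_scales_to_pow2(scales: torch.Tensor) -> torch.Tensor:
    # ue8m0 stores only an 8-bit exponent, so each scale is forced to a
    # power of two; rounding up keeps the quantized values representable.
    return torch.exp2(torch.ceil(torch.log2(scales.to(torch.float32))))

print(round_scales_to_pow2(torch.tensor([0.3, 1.0, 5.0])))  # tensor([0.5000, 1.0000, 8.0000])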


@@ -6,6 +6,7 @@ import deep_ep
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,
@@ -20,6 +21,8 @@ from vllm.v1.worker.ubatching import (
     dbo_maybe_run_recv_hook,
 )

+logger = init_logger(__name__)
+
 # DeepEP kernels quantize dispatch inputs in 128 element chunks.
 DEEPEP_QUANT_BLOCK_SIZE = 128
 DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
@@ -94,6 +97,29 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         self.handles: list[tuple | None] = [None, None]
         self.num_dispatchers_ = num_dispatchers

+        # At construction time we don't have enough information to decide
+        # whether activation scales should be dispatched in the packed ue8m0
+        # format. This setting is resolved later by post_init_setup.
+        self.use_ue8m0_dispatch = False
+
+    def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute):
+        if not fused_experts.supports_packed_ue8m0_act_scales():
+            # Early exit.
+            return
+
+        if self.use_fp8_dispatch:
+            logger.debug_once(
+                "Updating DeepEPLLPrepareAndFinalize to dispatch packed ue8m0 scales."
+            )
+            self.use_ue8m0_dispatch = True
+        else:
+            logger.warning_once(
+                "DeepEPLLPrepareAndFinalize is set up to dispatch raw/unquantized "
+                f"activations even though {fused_experts.__class__.__name__} can "
+                "consume quantized activations.",
+                scope="local",
+            )

     def num_dispatchers(self) -> int:
         return self.num_dispatchers_
@@ -206,6 +232,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             self.max_tokens_per_rank,
             num_experts,
             use_fp8=self.use_fp8_dispatch,
+            # round_scale needs to be set to dispatch in ue8m0
+            round_scale=self.use_ue8m0_dispatch,
+            use_ue8m0=self.use_ue8m0_dispatch,
             async_finish=False,
             return_recv_hook=True,
         )
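
With this change, DeepEP's low-latency dispatch rounds and packs activation scales as ue8m0 (round_scale/use_ue8m0) only when the downstream experts kernel advertises support for them. A minimal, self-contained sketch of that handshake, using hypothetical stand-in classes rather than the real vLLM/DeepEP objects:

class _SketchExperts:
    def supports_packed_ue8m0_act_scales(self) -> bool:
        return True  # e.g. DeepGEMM experts on sm100


class _SketchPrepareFinalize:
    def __init__(self, use_fp8_dispatch: bool):
        self.use_fp8_dispatch = use_fp8_dispatch
        self.use_ue8m0_dispatch = False

    def post_init_setup(self, experts) -> None:
        # Only enable packed ue8m0 dispatch when both sides agree.
        if self.use_fp8_dispatch and experts.supports_packed_ue8m0_act_scales():
            self.use_ue8m0_dispatch = True


pf = _SketchPrepareFinalize(use_fp8_dispatch=True)
pf.post_init_setup(_SketchExperts())
assert pf.use_ue8m0_dispatch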


@@ -149,6 +149,15 @@ class FusedMoEPrepareAndFinalize(ABC):
         described above.
         """

+    def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
+        """
+        Initialize FusedMoEPrepareAndFinalize settings that depend on the
+        FusedMoEPermuteExpertsUnpermute experts object.
+        FusedMoEPrepareAndFinalize implementations with such dependencies
+        may override this function.
+        """
+        return
+
     @abstractmethod
     def prepare(
         self,
@@ -503,6 +512,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         """
         raise NotImplementedError

+    def supports_packed_ue8m0_act_scales(self) -> bool:
+        """
+        A flag indicating whether or not this class can process packed ue8m0
+        activation scales.
+        """
+        return False
+
     def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
         """
         Workspace type: The dtype to use for the workspace tensors.
@@ -698,6 +714,8 @@ class FusedMoEModularKernel(torch.nn.Module):
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
+        self._post_init_setup()
+
         assert (
             prepare_finalize.activation_format == fused_experts.activation_formats[0]
         ), (
@@ -707,6 +725,13 @@
             f"{fused_experts.activation_formats[0]}"
         )

+    def _post_init_setup(self):
+        """
+        Resolve any leftover setup dependencies between self.prepare_finalize
+        and self.fused_experts here.
+        """
+        self.prepare_finalize.post_init_setup(self.fused_experts)
+
     def supports_expert_map(self) -> bool:
         """
         A flag indicating whether or not this class supports expert maps.
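
The base FusedMoEPrepareAndFinalize class ships post_init_setup() as a concrete no-op rather than an @abstractmethod, so only implementations that actually depend on the experts object (such as DeepEPLLPrepareAndFinalize above) need to override it, and FusedMoEModularKernel can call it unconditionally. A small sketch of that optional-hook pattern, with hypothetical class names:

from abc import ABC, abstractmethod

class PrepareAndFinalizeSketch(ABC):
    # Optional hook: a concrete default, so kernel-composition code can call
    # it on every implementation without checking for its existence.
    def post_init_setup(self, fused_experts) -> None:
        return

    # Required hook: every implementation must provide this.
    @abstractmethod
    def prepare(self, tokens): ...

class SimplePrepareFinalize(PrepareAndFinalizeSketch):
    # No experts-dependent state, so only the abstract method is overridden.
    def prepare(self, tokens):
        return tokens

Keeping the hook concrete also keeps the change backwards compatible: existing PrepareAndFinalize implementations need no edits.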


@@ -60,11 +60,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     create_fp8_input_scale,
     create_fp8_scale_parameter,
     create_fp8_weight_parameter,
-    expert_weight_is_col_major,
+    deepgemm_post_process_fp8_weight_block,
     maybe_post_process_fp8_weight_block,
     process_fp8_weight_block_strategy,
     process_fp8_weight_tensor_strategy,
-    requant_weight_ue8m0_inplace,
     validate_fp8_block_shape,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -94,7 +93,6 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
-    get_col_major_tma_aligned_tensor,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
 )
@@ -846,15 +844,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         # DeepGemm scales need to be transposed and aligned. We try to do
         # it ahead of time for performance reasons.
-        if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
-            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                    layer.w13_weight_scale_inv
-                )
-            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                    layer.w2_weight_scale_inv
-                )
+        if self.allow_deep_gemm:
+            dg_w13_weight, dg_w13_weight_scale_inv = (
+                deepgemm_post_process_fp8_weight_block(
+                    wq=layer.w13_weight.data,
+                    ws=layer.w13_weight_scale_inv.data,
+                    quant_block_shape=tuple(layer.weight_block_size),
+                    use_e8m0=is_deep_gemm_e8m0_used(),
+                )
+            )
+            dg_w2_weight, dg_w2_weight_scale_inv = (
+                deepgemm_post_process_fp8_weight_block(
+                    wq=layer.w2_weight.data,
+                    ws=layer.w2_weight_scale_inv.data,
+                    quant_block_shape=tuple(layer.weight_block_size),
+                    use_e8m0=is_deep_gemm_e8m0_used(),
+                )
+            )
+            layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
+            layer.w13_weight_scale_inv = Parameter(
+                dg_w13_weight_scale_inv, requires_grad=False
+            )
+            layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
+            layer.w2_weight_scale_inv = Parameter(
+                dg_w2_weight_scale_inv, requires_grad=False
+            )

         # If checkpoint is fp16, quantize in place.
         elif not self.quant_config.is_checkpoint_fp8_serialized:
@@ -990,31 +1004,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale

-        if is_deep_gemm_e8m0_used() and self.block_quant:
-            assert layer.weight_block_size is not None
-            # Re-quantise the expert weights so their scales are UE8M0.
-            block_sz = tuple(layer.weight_block_size)
-            requant_weight_ue8m0_inplace(
-                layer.w13_weight.data,
-                layer.w13_weight_scale_inv.data,
-                block_sz,
-            )
-            requant_weight_ue8m0_inplace(
-                layer.w2_weight.data,
-                layer.w2_weight_scale_inv.data,
-                block_sz,
-            )
-
-            # Ensure column-major TMA alignment expected by DeepGEMM.
-            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                    layer.w13_weight_scale_inv
-                )
-            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                    layer.w2_weight_scale_inv
-                )

     def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
             self.rocm_aiter_moe_enabled
@@ -1037,7 +1026,8 @@
         layer: torch.nn.Module,
     ) -> FusedMoEPermuteExpertsUnpermute:
         from vllm.model_executor.layers.fused_moe import (
-            BatchedTritonOrDeepGemmExperts,
+            BatchedDeepGemmExperts,
+            BatchedTritonExperts,
             TritonOrDeepGemmExperts,
         )
@@ -1053,20 +1043,24 @@
         ):
             max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
             assert max_num_tokens_per_rank is not None

+            experts_impl = (
+                BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
+            )
             logger.debug(
-                "BatchedTritonOrDeepGemmExperts(%s): "
-                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
+                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
+                experts_impl.__name__,
                 self.__class__.__name__,
                 max_num_tokens_per_rank,
                 self.weight_block_size,
                 False,
             )
-            return BatchedTritonOrDeepGemmExperts(
+            return experts_impl(
                 max_num_tokens=max_num_tokens_per_rank,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
-                allow_deep_gemm=self.allow_deep_gemm,
             )

         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
             experts = select_cutlass_fp8_gemm_impl(
                 self.moe,
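
The removed branches used expert_weight_is_col_major() to decide whether the scale tensors still needed DeepGEMM's column-major, TMA-aligned layout at runtime; the new prologue instead resolves the layout once at weight-load time via deepgemm_post_process_fp8_weight_block(). For readers unfamiliar with the old check, the sketch below is illustrative only and uses a simplified 2-D stride test rather than vLLM's actual helper:

import torch

def looks_col_major_2d(x: torch.Tensor) -> bool:
    # A 2-D tensor is column-major when walking down a column is the
    # unit-stride direction (stride 1 along dim 0).
    return x.stride(0) == 1 and x.stride(1) == x.size(0)

row_major = torch.randn(8, 4)               # default PyTorch layout
col_major = row_major.t().contiguous().t()  # same values, column-major storage
print(looks_col_major_2d(row_major), looks_col_major_2d(col_major))  # False True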


@@ -34,6 +34,7 @@ from vllm.utils.deep_gemm import (
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
     should_use_deepgemm_for_fp8_linear,
+    transform_sf_into_required_layout,
 )
 from vllm.utils.torch_utils import direct_register_custom_op
@@ -929,6 +930,50 @@ def requant_weight_ue8m0_inplace(
     s_old.copy_(s_requant)


+def deepgemm_post_process_fp8_weight_block(
+    wq: torch.Tensor, ws: torch.Tensor, quant_block_shape: tuple[int], use_e8m0: bool
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert wq.dtype == torch.float8_e4m3fn, (
+        "Expected quantized tensor dtype "
+        f"to be torch.float8_e4m3fn, got {wq.dtype} instead."
+    )
+    assert ws.dtype == torch.float32, (
+        f"Expected tensor scales dtype to be torch.float32, got {ws.dtype} instead."
+    )
+
+    if use_e8m0:
+        requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape)
+
+    original_ndim = wq.ndim
+    if wq.ndim == 2:
+        assert ws.ndim == 2
+        wq = wq.unsqueeze(0)
+        ws = ws.unsqueeze(0)
+
+    # From https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/utils/layout.hpp#L46
+    recipe = (1, 128, 128)
+
+    # Ref: https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/apis/gemm.hpp
+    # DeepGemm uses the `transform_sf_into_required_layout` function to
+    # put the scales into the required layout.
+    dg_ws = transform_sf_into_required_layout(
+        sf=ws,
+        mn=wq.size(1),
+        k=wq.size(2),
+        recipe=recipe,
+        num_groups=wq.size(0),
+        # is_sfa marks scale factors for A (the A in A @ B); weights are B.
+        is_sfa=False,
+    )
+
+    if original_ndim == 2:
+        wq = wq.squeeze(0)
+        dg_ws = dg_ws.squeeze(0)
+
+    return wq, dg_ws
+
+
 def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
     """Pad the weight tensor. This is an optimization on ROCm platform, which
     can benefit from tensors located far enough from one another in memory"""
@@ -1141,11 +1186,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
     should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
         layer.orig_dtype, layer.weight
     )
-    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
-        block_sz = tuple(layer.weight_block_size)
-        requant_weight_ue8m0_inplace(
-            layer.weight.data, layer.weight_scale.data, block_sz
-        )
+    if should_use_deepgemm:
+        dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
+            wq=layer.weight.data,
+            ws=layer.weight_scale.data,
+            quant_block_shape=tuple(layer.weight_block_size),
+            use_e8m0=is_deep_gemm_e8m0_used(),
+        )
+        layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False)
+        layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False)


 def expert_weight_is_col_major(x: torch.Tensor) -> bool:
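
The new helper is meant to be called once per weight at load time. A minimal usage sketch with synthetic shapes; it assumes a CUDA device, DeepGEMM installed, and a vLLM build that includes this commit:

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    deepgemm_post_process_fp8_weight_block,
)

n, k, block = 512, 1024, 128
wq = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)  # blockwise-quantized weight
ws = torch.rand(n // block, k // block, device="cuda") + 0.5   # per-block float32 scales

# Requantizes scales to ue8m0 (when requested) and applies DeepGEMM's
# scale-factor layout transform ahead of time instead of on every forward.
wq_dg, ws_dg = deepgemm_post_process_fp8_weight_block(
    wq=wq, ws=ws, quant_block_shape=(block, block), use_e8m0=True
)
print(wq_dg.shape, ws_dg.shape, ws_dg.dtype)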


@@ -49,10 +49,6 @@ def is_deep_gemm_e8m0_used() -> bool:
         logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
         return False

-    if envs.VLLM_USE_FLASHINFER_MOE_FP8:
-        logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
-        return False
-
     if envs.VLLM_USE_DEEP_GEMM_E8M0:
         logger.info_once("DeepGEMM E8M0 enabled on current platform.")
         return True
@@ -77,6 +73,7 @@ _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
 _get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
 _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
 _get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
+_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None


 def _lazy_init() -> None:
@@ -86,6 +83,7 @@ def _lazy_init() -> None:
     global _get_paged_mqa_logits_metadata_impl
     global _get_mn_major_tma_aligned_tensor_impl
     global _get_mk_alignment_for_contiguous_layout_impl
+    global _transform_sf_into_required_layout_impl

     # fast path
     if (
         _fp8_gemm_nt_impl is not None
@@ -95,6 +93,7 @@ def _lazy_init() -> None:
         or _fp8_paged_mqa_logits_impl is not None
         or _get_paged_mqa_logits_metadata_impl is not None
         or _get_mk_alignment_for_contiguous_layout_impl is not None
+        or _transform_sf_into_required_layout_impl is not None
     ):
         return
@@ -124,6 +123,9 @@ def _lazy_init() -> None:
     _get_mk_alignment_for_contiguous_layout_impl = getattr(
         _dg, "get_mk_alignment_for_contiguous_layout", None
     )
+    _transform_sf_into_required_layout_impl = getattr(
+        _dg, "transform_sf_into_required_layout", None
+    )


 def get_num_sms() -> int:
@@ -179,6 +181,15 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     )


+def transform_sf_into_required_layout(*args, **kwargs):
+    _lazy_init()
+    if _transform_sf_into_required_layout_impl is None:
+        return _missing(*args, **kwargs)
+    return _transform_sf_into_required_layout_impl(
+        *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
+    )
+
+
 def fp8_mqa_logits(
     q: torch.Tensor,
     kv: tuple[torch.Tensor, torch.Tensor],
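
The wrapper above follows vllm/utils/deep_gemm.py's lazy-binding pattern: the optional DeepGEMM symbol is resolved once with getattr and falls back to an informative error if the installed DeepGEMM is missing or too old. A reduced, self-contained sketch of that pattern (hypothetical simplified bodies; only the names mirror the real ones):

import importlib

_impl = None

def _missing(*_, **__):
    raise RuntimeError("DeepGEMM backend is not available or too old.")

def _lazy_init() -> None:
    # Resolve the optional dependency once and cache the result.
    global _impl
    if _impl is not None:
        return
    try:
        _dg = importlib.import_module("deep_gemm")
    except ImportError:
        return
    _impl = getattr(_dg, "transform_sf_into_required_layout", None)

def transform_sf_into_required_layout(*args, **kwargs):
    _lazy_init()
    if _impl is None:
        return _missing(*args, **kwargs)
    return _impl(*args, **kwargs)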