diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index b8a97e92ab79..869082f8231d 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -232,6 +232,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): """ super().__init__(quant_config) assert self.block_shape == get_mk_alignment_for_contiguous_layout() + assert self.quant_config.use_fp8_w8a8 self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @@ -250,6 +251,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def supports_expert_map(self) -> bool: return False + def supports_packed_ue8m0_act_scales(self) -> bool: + """ + DeepGemm supports packed ue8m0 activation scales format in devices == sm100 + """ + return current_platform.is_device_capability(100) + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceDelegate() diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 500bcefcfaa9..06c9df317f7c 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -6,6 +6,7 @@ import deep_ep import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, @@ -20,6 +21,8 @@ from vllm.v1.worker.ubatching import ( dbo_maybe_run_recv_hook, ) +logger = init_logger(__name__) + # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE] @@ -94,6 +97,29 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): self.handles: list[tuple | None] = [None, None] self.num_dispatchers_ = num_dispatchers + # We don't have enough information to determine if we should dispatch + # activation scales in a packed ue8m0 format during object construction + # time. This setting is handled by post_init_setup. + self.use_ue8m0_dispatch = False + + def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute): + if not fused_experts.supports_packed_ue8m0_act_scales(): + # Early exit. + return + + if self.use_fp8_dispatch: + logger.debug_once( + "Update DeepEPLLPrepareFinalize to do packed ue8m0 scales dispatch." 
+ ) + self.use_ue8m0_dispatch = True + else: + logger.warning_once( + "DeepEPLLPrepareAndFinalize is setup to dispatch raw/unquantized " + f"activations despite ({fused_experts.__class__.__name__}) being able " + "to support quantized activations.", + scope="local", + ) + def num_dispatchers(self) -> int: return self.num_dispatchers_ @@ -206,6 +232,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): self.max_tokens_per_rank, num_experts, use_fp8=self.use_fp8_dispatch, + # round_scale needs to be set to dispatch in ue8m0 + round_scale=self.use_ue8m0_dispatch, + use_ue8m0=self.use_ue8m0_dispatch, async_finish=False, return_recv_hook=True, ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index b5fa2c71bec5..a3142f37053f 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -149,6 +149,15 @@ class FusedMoEPrepareAndFinalize(ABC): described above. """ + def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"): + """ + Initialize FusedMoEPrepareAndFinalize settings that depend on + FusedMoEPermuteExpertsUnpermute experts object. + The FusedMoEPrepareAndFinalize implementations that have such + dependencies may choose to override this function. + """ + return + @abstractmethod def prepare( self, @@ -503,6 +512,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC): """ raise NotImplementedError + def supports_packed_ue8m0_act_scales(self) -> bool: + """ + A flag indicating whether or not this class can process packed ue8m0 + activation scales. + """ + return False + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: """ Workspace type: The dtype to use for the workspace tensors. @@ -698,6 +714,8 @@ class FusedMoEModularKernel(torch.nn.Module): self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts self.shared_experts = shared_experts + + self._post_init_setup() assert ( prepare_finalize.activation_format == fused_experts.activation_formats[0] ), ( @@ -707,6 +725,13 @@ class FusedMoEModularKernel(torch.nn.Module): f"{fused_experts.activation_formats[0]}" ) + def _post_init_setup(self): + """ + Resolve any leftover setup dependencies between self.prepare_finalize + and self.fused_experts here. + """ + self.prepare_finalize.post_init_setup(self.fused_experts) + def supports_expert_map(self) -> bool: """ A flag indicating whether or not this class supports expert maps. 
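
The two file diffs above introduce a small handshake: `FusedMoEModularKernel.__init__` now calls `_post_init_setup()`, which hands the experts object to `prepare_finalize.post_init_setup()`, and `DeepEPLLPrepareAndFinalize` turns on packed ue8m0 dispatch only when the experts implementation advertises support via `supports_packed_ue8m0_act_scales()` and fp8 dispatch is already enabled. Below is a minimal, self-contained sketch of that control flow; `StubExperts` and `StubPrepareFinalize` are illustrative stand-ins, not vLLM classes.

```python
# Minimal sketch of the prepare/finalize <-> experts handshake added above.
# StubExperts and StubPrepareFinalize are illustrative stand-ins, not vLLM classes.


class StubExperts:
    def supports_packed_ue8m0_act_scales(self) -> bool:
        # BatchedDeepGemmExperts returns True only on sm100 devices;
        # the FusedMoEPermuteExpertsUnpermute base class defaults to False.
        return True


class StubPrepareFinalize:
    def __init__(self, use_fp8_dispatch: bool):
        self.use_fp8_dispatch = use_fp8_dispatch
        # Unknown at construction time; resolved later by post_init_setup().
        self.use_ue8m0_dispatch = False

    def post_init_setup(self, fused_experts) -> None:
        if not fused_experts.supports_packed_ue8m0_act_scales():
            return
        # Only switch to packed ue8m0 scales when fp8 dispatch is active;
        # otherwise raw activations are dispatched and there is nothing to pack.
        self.use_ue8m0_dispatch = self.use_fp8_dispatch


# What FusedMoEModularKernel._post_init_setup() effectively does:
pf = StubPrepareFinalize(use_fp8_dispatch=True)
pf.post_init_setup(StubExperts())
assert pf.use_ue8m0_dispatch is True
```

In the real code path, `DeepEPLLPrepareAndFinalize` additionally logs a local warning when the experts kernel could consume quantized activations but `use_fp8_dispatch` is off, since the chance to dispatch packed scales is then lost.
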
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index cb065eb68b66..bbd0a4df1048 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -60,11 +60,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, - expert_weight_is_col_major, + deepgemm_post_process_fp8_weight_block, maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy, - requant_weight_ue8m0_inplace, validate_fp8_block_shape, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( @@ -94,7 +93,6 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils.deep_gemm import ( - get_col_major_tma_aligned_tensor, is_deep_gemm_e8m0_used, is_deep_gemm_supported, ) @@ -846,15 +844,31 @@ class Fp8MoEMethod(FusedMoEMethodBase): # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. - if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): - if expert_weight_is_col_major(layer.w13_weight_scale_inv): - layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w13_weight_scale_inv + if self.allow_deep_gemm: + dg_w13_weight, dg_w13_weight_scale_inv = ( + deepgemm_post_process_fp8_weight_block( + wq=layer.w13_weight.data, + ws=layer.w13_weight_scale_inv.data, + quant_block_shape=tuple(layer.weight_block_size), + use_e8m0=is_deep_gemm_e8m0_used(), ) - if expert_weight_is_col_major(layer.w2_weight_scale_inv): - layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w2_weight_scale_inv + ) + dg_w2_weight, dg_w2_weight_scale_inv = ( + deepgemm_post_process_fp8_weight_block( + wq=layer.w2_weight.data, + ws=layer.w2_weight_scale_inv.data, + quant_block_shape=tuple(layer.weight_block_size), + use_e8m0=is_deep_gemm_e8m0_used(), ) + ) + layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False) + layer.w13_weight_scale_inv = Parameter( + dg_w13_weight_scale_inv, requires_grad=False + ) + layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False) + layer.w2_weight_scale_inv = Parameter( + dg_w2_weight_scale_inv, requires_grad=False + ) # If checkpoint is fp16, quantize in place. elif not self.quant_config.is_checkpoint_fp8_serialized: @@ -990,31 +1004,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): del layer.w13_input_scale del layer.w2_input_scale - if is_deep_gemm_e8m0_used() and self.block_quant: - assert layer.weight_block_size is not None - # Re-quantise the expert weights so their scales are UE8M0. - block_sz = tuple(layer.weight_block_size) - requant_weight_ue8m0_inplace( - layer.w13_weight.data, - layer.w13_weight_scale_inv.data, - block_sz, - ) - requant_weight_ue8m0_inplace( - layer.w2_weight.data, - layer.w2_weight_scale_inv.data, - block_sz, - ) - - # Ensure column-major TMA alignment expected by DeepGEMM. 
- if expert_weight_is_col_major(layer.w13_weight_scale_inv): - layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w13_weight_scale_inv - ) - if expert_weight_is_col_major(layer.w2_weight_scale_inv): - layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w2_weight_scale_inv - ) - def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None: if ( self.rocm_aiter_moe_enabled @@ -1037,7 +1026,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( - BatchedTritonOrDeepGemmExperts, + BatchedDeepGemmExperts, + BatchedTritonExperts, TritonOrDeepGemmExperts, ) @@ -1053,20 +1043,24 @@ class Fp8MoEMethod(FusedMoEMethodBase): ): max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None + + experts_impl = ( + BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts + ) logger.debug( - "BatchedTritonOrDeepGemmExperts(%s): " - "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", + "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", + experts_impl.__name__, self.__class__.__name__, max_num_tokens_per_rank, self.weight_block_size, False, ) - return BatchedTritonOrDeepGemmExperts( + return experts_impl( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=self.moe_quant_config, - allow_deep_gemm=self.allow_deep_gemm, ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: experts = select_cutlass_fp8_gemm_impl( self.moe, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 4384857f9270..03d086bda8e3 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -34,6 +34,7 @@ from vllm.utils.deep_gemm import ( is_deep_gemm_e8m0_used, is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear, + transform_sf_into_required_layout, ) from vllm.utils.torch_utils import direct_register_custom_op @@ -929,6 +930,50 @@ def requant_weight_ue8m0_inplace( s_old.copy_(s_requant) +def deepgemm_post_process_fp8_weight_block( + wq: torch.Tensor, ws: torch.Tensor, quant_block_shape: tuple[int], use_e8m0: bool +) -> tuple[torch.Tensor, torch.Tensor]: + assert wq.dtype == torch.float8_e4m3fn, ( + "Expected quantized tensor dtype " + f"to be torch.float8_e4m3fn, got {wq.dtype} instead." + ) + assert ws.dtype == torch.float32, ( + f"Expected tensor scales dtype to be torch.float32, got {ws.dtype} instead" + ) + + if use_e8m0: + requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape) + + original_ndim = wq.ndim + if wq.ndim == 2: + assert ws.ndim == 2 + wq = wq.unsqueeze(0) + ws = ws.unsqueeze(0) + + # From https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/utils/layout.hpp#L46 + recipe = (1, 128, 128) + + # Ref : https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/apis/gemm.hpp + # DeepGemm uses the `transform_sf_into_required_layout` function to + # represent scales in the correct format. + dg_ws = transform_sf_into_required_layout( + sf=ws, + mn=wq.size(1), + k=wq.size(2), + recipe=recipe, + num_groups=wq.size(0), + # is the scale factors for A in (Refers to the argument A in A @ B). + # Weights are B. 
+ is_sfa=False, + ) + + if original_ndim == 2: + wq = wq.squeeze(0) + dg_ws = dg_ws.squeeze(0) + + return wq, dg_ws + + def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: """Pad the weight tensor. This is an optimization on ROCm platform, which can benefit from tensors located far enough from one another in memory""" @@ -1141,11 +1186,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module): should_use_deepgemm = should_use_deepgemm_for_fp8_linear( layer.orig_dtype, layer.weight ) - if is_deep_gemm_e8m0_used() and should_use_deepgemm: - block_sz = tuple(layer.weight_block_size) - requant_weight_ue8m0_inplace( - layer.weight.data, layer.weight_scale.data, block_sz + if should_use_deepgemm: + dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block( + wq=layer.weight.data, + ws=layer.weight_scale.data, + quant_block_shape=tuple(layer.weight_block_size), + use_e8m0=is_deep_gemm_e8m0_used(), ) + layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False) + layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False) def expert_weight_is_col_major(x: torch.Tensor) -> bool: diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index a928cce09011..4c15baf7a8f9 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -49,10 +49,6 @@ def is_deep_gemm_e8m0_used() -> bool: logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") return False - if envs.VLLM_USE_FLASHINFER_MOE_FP8: - logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.") - return False - if envs.VLLM_USE_DEEP_GEMM_E8M0: logger.info_once("DeepGEMM E8M0 enabled on current platform.") return True @@ -77,6 +73,7 @@ _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None _get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None _get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None +_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None def _lazy_init() -> None: @@ -86,6 +83,7 @@ def _lazy_init() -> None: global _get_paged_mqa_logits_metadata_impl global _get_mn_major_tma_aligned_tensor_impl global _get_mk_alignment_for_contiguous_layout_impl + global _transform_sf_into_required_layout_impl # fast path if ( _fp8_gemm_nt_impl is not None @@ -95,6 +93,7 @@ def _lazy_init() -> None: or _fp8_paged_mqa_logits_impl is not None or _get_paged_mqa_logits_metadata_impl is not None or _get_mk_alignment_for_contiguous_layout_impl is not None + or _transform_sf_into_required_layout_impl is not None ): return @@ -124,6 +123,9 @@ def _lazy_init() -> None: _get_mk_alignment_for_contiguous_layout_impl = getattr( _dg, "get_mk_alignment_for_contiguous_layout", None ) + _transform_sf_into_required_layout_impl = getattr( + _dg, "transform_sf_into_required_layout", None + ) def get_num_sms() -> int: @@ -179,6 +181,15 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): ) +def transform_sf_into_required_layout(*args, **kwargs): + _lazy_init() + if _transform_sf_into_required_layout_impl is None: + return _missing(*args, **kwargs) + return _transform_sf_into_required_layout_impl( + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs + ) + + def fp8_mqa_logits( q: torch.Tensor, kv: tuple[torch.Tensor, torch.Tensor],
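
For context on the fp8.py and fp8_utils.py changes above: `deepgemm_post_process_fp8_weight_block` consolidates the DeepGEMM weight post-processing that was previously scattered across `process_weights_after_loading`, namely an optional in-place UE8M0 re-quantization of the block scales followed by `transform_sf_into_required_layout` on the B-operand scales (recipe `(1, 128, 128)`, `is_sfa=False`). As intuition for the UE8M0 step, here is a small illustrative re-quantization of a single quant block. It is not vLLM's `requant_weight_ue8m0_inplace`; it assumes a PyTorch build with `float8_e4m3fn` support, and rounding the scale up to a power of two is shown as one common convention rather than the exact vLLM rounding.

```python
# Illustrative sketch (not the vLLM kernel) of what UE8M0 re-quantization means:
# UE8M0 is an unsigned, exponent-only 8-bit scale format, so every scale must be
# a power of two. Rounding the block scale up keeps the re-quantized fp8 values
# in range.
import torch


def requant_block_to_ue8m0(wq_block: torch.Tensor, scale: torch.Tensor):
    """wq_block: fp8 values of one quant block, scale: its float32 scale."""
    w = wq_block.to(torch.float32) * scale                   # dequantize with old scale
    new_scale = torch.exp2(torch.ceil(torch.log2(scale)))    # round up to a power of two
    wq_new = (w / new_scale).to(torch.float8_e4m3fn)         # re-quantize with new scale
    return wq_new, new_scale


# Example: a 1x128 block with an arbitrary scale.
blk = torch.randn(1, 128).to(torch.float8_e4m3fn)
wq_new, s_new = requant_block_to_ue8m0(blk, torch.tensor(0.0123))
assert torch.log2(s_new) == torch.log2(s_new).round()        # scale is now 2**k
```

Because every UE8M0 scale is a power of two, only its 8-bit exponent needs to travel with the activations, which is what makes the packed-scale dispatch path in DeepEP possible.
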
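
The `vllm/utils/deep_gemm.py` hunk extends the module's existing lazy-binding pattern for optional DeepGEMM symbols. Below is a generic, runnable sketch of that pattern using `transform_sf_into_required_layout` as the example symbol; the error message and fallback behavior are illustrative, not vLLM's `_missing` helper.

```python
# Sketch of the lazy-binding pattern: look the symbol up in the optional
# DeepGEMM package once, cache it in a module-level slot, and only report it
# missing when it is actually called.
import importlib
from collections.abc import Callable
from typing import Any

_transform_impl: Callable[..., Any] | None = None
_initialized = False


def _lazy_init() -> None:
    global _transform_impl, _initialized
    if _initialized:
        return
    _initialized = True
    try:
        dg = importlib.import_module("deep_gemm")
    except ImportError:
        return  # DeepGEMM not installed; the wrapper reports it on first use.
    # Older DeepGEMM builds may not export the symbol, hence getattr(..., None).
    _transform_impl = getattr(dg, "transform_sf_into_required_layout", None)


def transform_sf_into_required_layout(*args, **kwargs):
    _lazy_init()
    if _transform_impl is None:
        raise RuntimeError(
            "transform_sf_into_required_layout requires a DeepGEMM build that exports it."
        )
    return _transform_impl(*args, **kwargs)
```

The actual wrapper in the diff also injects `disable_ue8m0_cast=not is_deep_gemm_e8m0_used()`, so callers do not have to thread the E8M0 decision through every call site.
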