diff --git a/vllm/envs.py b/vllm/envs.py
index 261cc7855b70..0896ae3a96c7 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -119,7 +119,8 @@ if TYPE_CHECKING:
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_USE_DEEP_GEMM: bool = False
-    VLLM_USE_FLASHINFER_MOE: bool = False
+    VLLM_USE_FLASHINFER_MOE_FP8: bool = False
+    VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -854,9 +855,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
 
+    # Allow use of FlashInfer MoE kernels for fused moe ops.
+    "VLLM_USE_FLASHINFER_MOE_FP8":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
+
     # Allow use of FlashInfer CUTLASS kernels for fused moe ops.
-    "VLLM_USE_FLASHINFER_MOE":
-    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),
+    "VLLM_USE_FLASHINFER_MOE_FP4":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))),
 
     # Control the cache sized used by the xgrammar compiler. The default
     # of 512 MB should be enough for roughly 1000 JSON schemas.
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 9bebb6a65fce..51c421bd228f 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -191,7 +191,7 @@ class FusedMoEParallelConfig:
 
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return (envs.VLLM_USE_FLASHINFER_MOE
+        return (envs.VLLM_USE_FLASHINFER_MOE_FP4
                 and has_flashinfer_cutlass_fused_moe())
 
     @staticmethod
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index aec5d7b252e3..c412f695ae76 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, moe_kernel_quantize_input)
+    _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform
@@ -1061,6 +1061,104 @@ direct_register_custom_op(
 )
 
 
+def next_positive_power_of_2(x: int) -> int:
+    if x < 1:
+        return 1
+    return 1 << (x - 1).bit_length()
+
+
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
+def flashinfer_fused_moe_blockscale_fp8(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        x: torch.Tensor,
+        w13_weight: torch.Tensor,
+        w13_weight_scale_inv: torch.Tensor,
+        w2_weight: torch.Tensor,
+        w2_weight_scale_inv: torch.Tensor,
+        global_num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        expert_offset: int,
+        local_num_experts: int,
+        block_shape: list[int],
+        routed_scaling: float = 1.0) -> torch.Tensor:
+    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
+    assert top_k <= global_num_experts
+    assert top_k <= 8
+    assert topk_group <= 4
+    assert global_num_experts > num_expert_group
+    assert global_num_experts % num_expert_group == 0
+    assert global_num_experts % 4 == 0
+    assert top_k < (topk_group * global_num_experts / num_expert_group)
+    assert block_shape == [128, 128]
+
+    a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
+    # NOTE: scales of hidden states have to be transposed!
+    a_sf_t = a_sf.t().contiguous()
+    return flashinfer_trtllm_fp8_block_scale_moe(
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=a_q,
+        hidden_states_scale=a_sf_t,
+        gemm1_weights=w13_weight,
+        gemm1_weights_scale=w13_weight_scale_inv,
+        gemm2_weights=w2_weight,
+        gemm2_weights_scale=w2_weight_scale_inv,
+        num_experts=global_num_experts,
+        top_k=top_k,
+        n_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=intermediate_size,
+        local_expert_offset=expert_offset,
+        local_num_experts=local_num_experts,
+        routed_scaling_factor=routed_scaling,
+        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
+                                             global_num_experts),
+        routing_method_type=2,  # DeepSeek-styled routing method
+    )
+
+
+def flashinfer_fused_moe_blockscale_fp8_fake(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        x: torch.Tensor,
+        w13_weight: torch.Tensor,
+        w13_weight_scale_inv: torch.Tensor,
+        w2_weight: torch.Tensor,
+        w2_weight_scale_inv: torch.Tensor,
+        global_num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        expert_offset: int,
+        local_num_experts: int,
+        block_shape: list[int],
+        routed_scaling: float = 1.0) -> torch.Tensor:
+    return torch.empty_like(x)
+
+
+direct_register_custom_op(
+    op_name="flashinfer_fused_moe_blockscale_fp8",
+    op_func=flashinfer_fused_moe_blockscale_fp8,
+    mutates_args=[],
+    fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 def outplace_fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 824dfe15ae25..35d7545d8c6a 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -43,6 +43,7 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.flashinfer import has_flashinfer_moe
 
 if TYPE_CHECKING:
     from vllm.model_executor.models.utils import WeightsMapper
@@ -52,6 +53,11 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 logger = init_logger(__name__)
 
 
+def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
+    return x.reshape(-1, 2, x.shape[-2] // 2,
+                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
+
+
 def _is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3
     b, m, n = x.shape
@@ -473,6 +479,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
+        self.flashinfer_moe_enabled = False
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+            logger.info_once(
+                "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")
+            self.flashinfer_moe_enabled = True
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
@@ -674,6 +685,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     normalize_e4m3fn_to_e4m3fnuz(
                         layer.w2_weight, layer.w2_weight_scale_inv,
                         layer.w2_input_scale)
+            elif self.flashinfer_moe_enabled:
+                # NOTE: weights have to be swapped since the activation is
+                # applied on a different half for flashinfer vs vllm
+                w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
+                w13_weight_scale_inv = _swap_w13_to_w31(
+                    layer.w13_weight_scale_inv.data)
+                w2_weight = layer.w2_weight.data
+                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
             else:
                 w13_weight = layer.w13_weight.data
                 w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@@ -915,25 +934,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
             assert isinstance(layer, FusedMoE)
-
-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
-        )
+        if not self.flashinfer_moe_enabled:
+            topk_weights, topk_ids = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias,
+                indices_type=self.topk_indices_dtype,
+                enable_eplb=enable_eplb,
+                expert_map=expert_map,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
 
         if self.rocm_aiter_moe_enabled:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
@@ -971,6 +990,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 apply_router_weight_on_input=apply_router_weight_on_input,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
+        elif self.flashinfer_moe_enabled:
+            # Currently only works with DeepSeek-style (DS) models
+            assert self.block_quant
+            assert (renormalize and use_grouped_topk
+                    and scoring_func == 'sigmoid'
+                    and custom_routing_function is None)
+            assert activation == "silu"
+            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+                routing_logits=router_logits.to(torch.float32),
+                routing_bias=e_score_correction_bias,
+                x=x,
+                w13_weight=layer.w13_weight,
+                w13_weight_scale_inv=layer.w13_weight_scale_inv,
+                w2_weight=layer.w2_weight,
+                w2_weight_scale_inv=layer.w2_weight_scale_inv,
+                global_num_experts=global_num_experts,
+                top_k=top_k,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                intermediate_size=layer.intermediate_size_per_partition,
+                expert_offset=layer.ep_rank * layer.local_num_experts,
+                local_num_experts=layer.local_num_experts,
+                block_shape=self.quant_config.weight_block_size,
+                routed_scaling=1.0,
+            )
         else:
             return self.fused_experts(
                 hidden_states=x,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 3807899fc3e5..20def70d1976 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -721,7 +721,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self.use_marlin = False
         self.allow_flashinfer_cutlass = False
 
-        if envs.VLLM_USE_FLASHINFER_MOE:
+        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
             if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
                 and current_platform.is_device_capability(100):
                 logger.info_once(
@@ -800,10 +800,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 assert moe.dp_size > 1
                 logger.debug_once("Using CutlassExpertsFp4")
                 # Currently CutlassExpertsFp4 doesn't support DP
-                raise ValueError(
-                    "CutlassExpertsFp4 doesn't support DP. "
-                    "Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)"
-                    " backend instead.")
+                raise ValueError("CutlassExpertsFp4 doesn't support DP. "
+                                 "Use flashinfer CUTLASS FusedMoE backend instead "
+                                 "(set VLLM_USE_FLASHINFER_MOE_FP4=1)")
 
         return experts
 
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index dbd2dc393046..fd8b384a616f 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -64,6 +64,8 @@ def _lazy_import_wrapper(module_name: str,
 
 
 # Create lazy wrappers for each function
+flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
+    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")
 flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
                                                     "cutlass_fused_moe")
 fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
@@ -77,10 +79,16 @@ autotune = _lazy_import_wrapper(
     fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())
 
 
+@functools.cache
+def has_flashinfer_moe() -> bool:
+    """Return ``True`` if FlashInfer MoE module is available."""
+    return importlib.util.find_spec("flashinfer.fused_moe") is not None
+
+
 @functools.cache
 def has_flashinfer_cutlass_fused_moe() -> bool:
     """Return ``True`` if FlashInfer CUTLASS fused MoE is available."""
-    if not has_flashinfer():
+    if not has_flashinfer_moe():
         return False
 
     # Check if all required functions are available
@@ -99,9 +107,11 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
 
 __all__ = [
     "has_flashinfer",
-    "has_flashinfer_cutlass_fused_moe",
     "flashinfer_trtllm_fp8_block_scale_moe",
     "flashinfer_cutlass_fused_moe",
     "fp4_quantize",
     "fp4_swizzle_blockscale",
     "autotune",
+    "has_flashinfer_moe",
+    "has_flashinfer_cutlass_fused_moe",
 ]
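The `_swap_w13_to_w31` helper added to fp8.py reorders the fused w13 weight (and its block scales) from a [w1; w3] to a [w3; w1] layout along the intermediate dimension, because FlashInfer applies the gating activation to the opposite half of the fused gate/up GEMM output compared to vLLM. The following is a minimal standalone sketch, using toy shapes chosen only for illustration (not part of the patch), showing what the reshape/flip does:

import torch


def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    # Same reshape/flip as the fp8.py helper: split dim -2 into two halves,
    # swap them, and restore the original shape.
    return x.reshape(-1, 2, x.shape[-2] // 2,
                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)


# Toy w13 weight: 1 expert, intermediate size 2 (so dim -2 has 2 * 2 = 4 rows),
# hidden size 3. Rows 0-1 stand in for w1, rows 2-3 for w3.
w13 = torch.arange(12.).reshape(1, 4, 3)
w31 = _swap_w13_to_w31(w13)
assert torch.equal(w31[:, :2], w13[:, 2:])  # the w3 half now comes first
assert torch.equal(w31[:, 2:], w13[:, :2])  # the w1 half now comes second

Since the swap runs once in process_weights_after_loading, the apply path can pass layer.w13_weight and layer.w13_weight_scale_inv to the FlashInfer kernel as-is, with no per-forward reordering.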