diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 1753c4f6e2387..3e79a1a8c24b2 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -181,12 +181,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): g2_alphas, ] _ = flashinfer_cutlass_fused_moe( - hidden_states, - topk_ids.to(torch.int), - topk_weights, + input=hidden_states, + token_selected_experts=topk_ids.to(torch.int), + token_final_scales=topk_weights, # FlashInfer API requires weight to be long for nvfp4 - w1.view(torch.long), - w2.view(torch.long), + fc1_expert_weights=w1.view(torch.long), + fc2_expert_weights=w2.view(torch.long), output_dtype=out_dtype, quant_scales=quant_scales, input_sf=a1q_scale, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 49819504c8ec8..e658990e95e58 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -11,7 +11,7 @@ from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( extract_required_args, moe_kernel_quantize_input) -from vllm.utils.flashinfer import fp4_swizzle_blockscale +from vllm.utils.flashinfer import block_scale_interleave def get_local_sizes(local_tokens): @@ -92,7 +92,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): dim=0, sizes=get_local_sizes(local_tokens)) a1_m, a1_n = a1q.shape - a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2) + a1q_scale = block_scale_interleave(a1q_scale) return a1q, a1q_scale, None, topk_ids, topk_weights diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index fd8b384a616f5..1ddafbae7fc0e 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -69,8 +69,8 @@ flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper( flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", "cutlass_fused_moe") fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") -fp4_swizzle_blockscale = _lazy_import_wrapper("flashinfer", - "fp4_swizzle_blockscale") +block_scale_interleave = _lazy_import_wrapper("flashinfer", + "block_scale_interleave") # Special case for autotune since it returns a context manager autotune = _lazy_import_wrapper( @@ -95,7 +95,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool: required_functions = [ ("flashinfer.fused_moe", "cutlass_fused_moe"), ("flashinfer", "fp4_quantize"), - ("flashinfer", "fp4_swizzle_blockscale"), + ("flashinfer", "block_scale_interleave"), ] for module_name, attr_name in required_functions: @@ -110,7 +110,7 @@ __all__ = [ "flashinfer_trtllm_fp8_block_scale_moe", "flashinfer_cutlass_fused_moe", "fp4_quantize", - "fp4_swizzle_blockscale", + "block_scale_interleave", "autotune", "has_flashinfer_moe", "has_flashinfer_cutlass_fused_moe",