Add TRTLLM MoE NVFP4 kernel to CompressedTensorsW4A4MoeMethod (#28892)
Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent e99e467384
commit b4c8fbaae2
@@ -8,6 +8,7 @@ from enum import Enum
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
from torch.nn.parameter import Parameter

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -50,9 +51,15 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    flashinfer_trtllm_fp4_moe,
    prepare_static_weights_for_trtllm_fp4_moe,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    FlashinferMoeBackend,
    get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    expert_weight_is_col_major,
    requant_weight_ue8m0_inplace,
@@ -193,6 +200,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        self.group_size = 16
        self.flashinfer_moe_backend = None
        if self.allow_flashinfer:
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
                " for CompressedTensorsW4A4MoeMethod."
            )

    def create_weights(
        self,
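Note: everything TRT-LLM-specific in the hunks below is gated on the same predicate, allow_flashinfer plus a TENSORRT_LLM backend. A minimal standalone sketch of that gating; the enum here is a stand-in for vllm's FlashinferMoeBackend (only the two member names used in this diff are taken from it, the values are arbitrary):

from enum import Enum

class FlashinferMoeBackendSketch(Enum):  # stand-in for FlashinferMoeBackend
    CUTLASS = "cutlass"
    TENSORRT_LLM = "tensorrt_llm"

def uses_trtllm_fp4(allow_flashinfer: bool, backend: FlashinferMoeBackendSketch | None) -> bool:
    # Same condition repeated in process_weights_after_loading / apply below.
    return allow_flashinfer and backend == FlashinferMoeBackendSketch.TENSORRT_LLM

assert uses_trtllm_fp4(True, FlashinferMoeBackendSketch.TENSORRT_LLM)
assert not uses_trtllm_fp4(True, FlashinferMoeBackendSketch.CUTLASS)
assert not uses_trtllm_fp4(False, FlashinferMoeBackendSketch.TENSORRT_LLM)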
@@ -344,21 +358,20 @@
        if self.use_marlin:
            prepare_moe_fp4_layer_for_marlin(layer)
            return

        # swizzle weight scales
        layer.w13_weight_scale = torch.nn.Parameter(
            swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
        )

        layer.w2_weight_scale = torch.nn.Parameter(
            swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
        )

        # w13
        w13_input_global_scale = layer.w13_input_global_scale.max(dim=1).values.to(
            torch.float32
        )

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            w13_input_global_scale = (
                layer.w13_input_global_scale.min()
                .to(torch.float32)
                .expand(layer.num_experts)
            )
        else:
            w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
                torch.float32
            )
        layer.g1_alphas = torch.nn.Parameter(
            ((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
            requires_grad=False,
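Note: the hunk above changes how the w13 input global scale is reduced. In the default path the reduction stays per-expert, while in the TRT-LLM path a single global minimum is broadcast to every expert, presumably because the fused kernel consumes one activation scale. A minimal torch sketch of the two reductions (toy shape [num_experts, 2], illustrative only):

import torch

num_experts = 4
w13_input_global_scale = torch.rand(num_experts, 2)

# Default path: one scale per expert (reduce over the second dimension).
per_expert = w13_input_global_scale.min(dim=1).values.to(torch.float32)

# TRT-LLM path: one shared scale, broadcast to every expert.
shared = w13_input_global_scale.min().to(torch.float32).expand(num_experts)

assert per_expert.shape == shared.shape == (num_experts,)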
@@ -369,22 +382,92 @@
        )

        # w2
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            w2_input_global_scale = (
                layer.w2_input_global_scale.min()
                .to(torch.float32)
                .expand(layer.num_experts)
            )
        else:
            w2_input_global_scale = layer.w2_input_global_scale

        layer.g2_alphas = torch.nn.Parameter(
            ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
                torch.float32
            ),
            ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        layer.w2_input_scale_quant = torch.nn.Parameter(
            (layer.w2_input_global_scale), requires_grad=False
            (w2_input_global_scale), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        else:
            # swizzle weight scales
            layer.w13_weight_scale = torch.nn.Parameter(
                swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
            )

            layer.w2_weight_scale = torch.nn.Parameter(
                swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        if self.use_marlin:
        if self.use_marlin or (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif not self.allow_flashinfer:
            return super().maybe_make_prepare_finalize(routing_tables)
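Note: g1_scale_c is the one extra parameter the TRT-LLM kernel needs beyond the existing alphas; it appears to fold the GEMM1 output dequant scale (g1_alphas) together with the GEMM2 input quant scale (w2_input_scale_quant) into a single per-expert factor. A tiny numeric sketch of that algebra, using the definitions from this hunk (toy scalars, illustrative only):

import torch

# Toy per-expert scalars standing in for the layer attributes above.
w13_input_global_scale = torch.tensor([0.5])
w13_weight_scale_2 = torch.tensor([2.0])
w2_input_global_scale = torch.tensor([4.0])

g1_alphas = (1 / w13_input_global_scale) * w13_weight_scale_2      # as set for layer.g1_alphas
w2_input_scale_quant = w2_input_global_scale                       # as set for layer.w2_input_scale_quant
g1_scale_c = (w2_input_scale_quant * g1_alphas).to(torch.float32)  # as set for layer.g1_scale_c

print(g1_alphas.item(), g1_scale_c.item())  # 4.0, 16.0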
@@ -411,7 +494,10 @@
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        if self.use_marlin:
        if (
            self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        return nvfp4_moe_quant_config(
@@ -452,6 +538,22 @@
        )
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=top_k,
                global_num_experts=global_num_experts,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                custom_routing_function=custom_routing_function,
                e_score_correction_bias=e_score_correction_bias,
            )

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
@@ -15,7 +15,6 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    RoutingMethodType,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
@@ -38,6 +37,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    flashinfer_trtllm_fp4_moe,
    prepare_static_weights_for_trtllm_fp4_moe,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
@@ -1136,7 +1137,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        self.flashinfer_moe_backend = None
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
        if self.allow_flashinfer:
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
@@ -1303,138 +1303,14 @@
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def prepare_static_weights_for_trtllm_fp4_moe(
        self,
        # args_dequant,
        # args,
        gemm1_weights,
        gemm2_weights,
        gemm1_scales_linear_fp4_bytes,
        gemm2_scales_linear_fp4_bytes,
        hidden_size,
        intermediate_size,
        num_experts,
    ):
        from flashinfer import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import (
            _maybe_get_cached_w3_w1_permute_indices,
            get_w2_permute_indices_with_cache,
        )

        """Prepare quantized weights for kernel (done offline with weights)."""
        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2
        )  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 16
        )  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2
        )  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, hidden_size, intermediate_size // 16
        )  # fp8 scaling factors

        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            # Calculate the permute indices for the following:
            # 1. Reorder rows of W1 and scales for fused gated activation
            # 2. Shuffle weights and scaling factors for transposed mma output
            # for both w3_w1 and w2 weights and scale factors
            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_fp4_shuffled.append(
                gemm1_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm1_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm1_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

            permute_indices = get_w2_permute_indices_with_cache(
                self._cache_permute_indices,
                gemm2_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_fp4_shuffled.append(
                gemm2_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = get_w2_permute_indices_with_cache(
                self._cache_permute_indices,
                gemm2_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm2_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm2_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
        )

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, hidden_size, intermediate_size // 16)
        )
        return (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        if self.allow_flashinfer and (
            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
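Note: reorder_w1w3_to_w3w1 is now applied for the TENSORRT_LLM backend as well, not only CUTLASS. Conceptually it swaps the gate (w1) and up (w3) halves of the fused gate/up projection along the weight's row dimension; a standalone torch sketch of that idea (not the helper's actual implementation, which also reorders the block scales):

import torch

def reorder_w1w3_to_w3w1_sketch(w13: torch.Tensor, dim: int = -2) -> torch.Tensor:
    # Split the fused [w1; w3] rows into two halves and concatenate them as [w3; w1].
    w1, w3 = w13.chunk(2, dim=dim)
    return torch.cat([w3, w1], dim=dim)

w13 = torch.arange(8).reshape(4, 2)               # rows 0-1 play the role of w1, rows 2-3 of w3
print(reorder_w1w3_to_w3w1_sketch(w13, dim=-2))   # rows 2-3 first, then rows 0-1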
@@ -1508,7 +1384,7 @@
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = self.prepare_static_weights_for_trtllm_fp4_moe(
            ) = prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
@@ -1614,68 +1490,17 @@
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            a1_gscale = layer.w13_input_scale_quant
            (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
                flashinfer.fp4_quantize(
                    x,
                    a1_gscale,
                    is_sf_swizzled_layout=False,
                )
            )
            use_llama4_routing = (
                custom_routing_function is Llama4MoE.custom_routing_function
            )
            routing_method_type = layer.routing_method_type
            if use_llama4_routing:
                routing_method_type = RoutingMethodType.Llama4
            router_logits = (
                router_logits.to(torch.float32)
                if routing_method_type == RoutingMethodType.DeepSeekV3
                else router_logits
            )
            routing_bias = e_score_correction_bias
            if routing_bias is not None:
                routing_bias = routing_bias.to(torch.bfloat16)
            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                routing_logits=router_logits,
                routing_bias=routing_bias,
                hidden_states=hidden_states_fp4,
                hidden_states_scale=hidden_states_scale_linear_fp4.view(
                    torch.float8_e4m3fn
                ).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                num_experts=global_num_experts,
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=top_k,
                n_group=num_expert_group,
                global_num_experts=global_num_experts,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                intermediate_size=layer.intermediate_size_per_partition,
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                routed_scaling_factor=1.0,
                tile_tokens_dim=None,
                routing_method_type=routing_method_type,
                do_finalize=True,
            )[0]
            return out
                custom_routing_function=custom_routing_function,
                e_score_correction_bias=e_score_correction_bias,
            )

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
@@ -9,6 +9,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
    FlashInferCuteDSLExperts,
@@ -110,3 +111,223 @@ def select_nvfp4_gemm_impl(
        "CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
        "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)"
    )


def prepare_static_weights_for_trtllm_fp4_moe(
    # args_dequant,
    # args,
    gemm1_weights,
    gemm2_weights,
    gemm1_scales_linear_fp4_bytes,
    gemm2_scales_linear_fp4_bytes,
    hidden_size,
    intermediate_size,
    num_experts,
):
    from flashinfer import nvfp4_block_scale_interleave
    from flashinfer.fused_moe.core import (
        _maybe_get_cached_w3_w1_permute_indices,
        get_w2_permute_indices_with_cache,
    )

    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    """Prepare quantized weights for kernel (done offline with weights)."""
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

    # Convert quantized weights to proper formats
    gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 2
    )  # packed fp4
    gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
        torch.float8_e4m3fn
    ).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 16
    )  # fp8 scaling factors

    gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 2
    )  # packed fp4
    gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
        torch.float8_e4m3fn
    ).reshape(num_experts, hidden_size, intermediate_size // 16)  # fp8 scaling factors

    gemm1_weights_fp4_shuffled = []
    gemm1_scales_fp4_shuffled = []
    gemm2_weights_fp4_shuffled = []
    gemm2_scales_fp4_shuffled = []
    for i in range(num_experts):
        # Calculate the permute indices for the following:
        # 1. Reorder rows of W1 and scales for fused gated activation
        # 2. Shuffle weights and scaling factors for transposed mma output
        # for both w3_w1 and w2 weights and scale factors
        permute_indices = _maybe_get_cached_w3_w1_permute_indices(
            _cache_permute_indices,
            gemm1_weights_fp4[i].view(torch.uint8),
            epilogue_tile_m,
        )
        gemm1_weights_fp4_shuffled.append(
            gemm1_weights_fp4[i]
            .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
            .contiguous()
        )

        permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
            _cache_permute_indices,
            gemm1_scales_linear_fp4[i].view(torch.uint8),
            epilogue_tile_m,
            num_elts_per_sf=16,
        )
        gemm1_scales_fp4_shuffled.append(
            nvfp4_block_scale_interleave(
                gemm1_scales_linear_fp4[i]
                .view(torch.uint8)[
                    permute_sf_indices.to(gemm1_scales_linear_fp4.device)
                ]
                .contiguous()
            )
        )

        permute_indices = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            gemm2_weights_fp4[i].view(torch.uint8),
            epilogue_tile_m,
        )
        gemm2_weights_fp4_shuffled.append(
            gemm2_weights_fp4[i]
            .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
            .contiguous()
        )

        permute_sf_indices = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            gemm2_scales_linear_fp4[i].view(torch.uint8),
            epilogue_tile_m,
            num_elts_per_sf=16,
        )
        gemm2_scales_fp4_shuffled.append(
            nvfp4_block_scale_interleave(
                gemm2_scales_linear_fp4[i]
                .view(torch.uint8)[
                    permute_sf_indices.to(gemm2_scales_linear_fp4.device)
                ]
                .contiguous()
            )
        )

    # Stack weights for all experts
    gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
    gemm1_scales_fp4_shuffled = (
        torch.stack(gemm1_scales_fp4_shuffled)
        .view(torch.float8_e4m3fn)
        .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
    )

    gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
    gemm2_scales_fp4_shuffled = (
        torch.stack(gemm2_scales_fp4_shuffled)
        .view(torch.float8_e4m3fn)
        .reshape(num_experts, hidden_size, intermediate_size // 16)
    )
    return (
        gemm1_weights_fp4_shuffled,
        gemm1_scales_fp4_shuffled,
        gemm2_weights_fp4_shuffled,
        gemm2_scales_fp4_shuffled,
    )
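Note on the shape arithmetic in the reshapes above: NVFP4 weights arrive byte-packed (two 4-bit values per uint8 byte), so a logical row of hidden_size elements occupies hidden_size // 2 bytes, and the block scales hold one fp8 value per 16-element group, i.e. hidden_size // 16 per row. A quick runnable check of those shapes (toy sizes, illustrative only):

import torch

num_experts, intermediate_size, hidden_size = 2, 64, 128

# Byte-packed FP4 weights for GEMM1 (fused w3/w1): two 4-bit values per byte.
gemm1_weights_bytes = torch.empty(
    num_experts, 2 * intermediate_size, hidden_size // 2, dtype=torch.uint8
)
# One fp8 block scale per 16 elements along the reduction dimension.
gemm1_scales = torch.empty(
    num_experts, 2 * intermediate_size, hidden_size // 16, dtype=torch.float8_e4m3fn
)

assert gemm1_weights_bytes.shape[-1] * 2 == hidden_size
assert gemm1_scales.shape[-1] * 16 == hidden_size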
def flashinfer_trtllm_fp4_moe(
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    num_expert_group: int | None,
    topk_group: int | None,
    custom_routing_function: object | None,
    e_score_correction_bias: torch.Tensor | None,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor
        router_logits: Router logits for expert selection
        top_k: Number of experts to select per token
        global_num_experts: Total number of experts across all ranks
        num_expert_group: Number of expert groups (for grouped routing)
        topk_group: Top-k within each group
        custom_routing_function: Custom routing function (e.g., Llama4)
        e_score_correction_bias: Optional routing bias correction

    Returns:
        Output tensor from the MoE layer
    """
    import flashinfer

    from vllm.model_executor.models.llama4 import Llama4MoE

    # Quantize input to FP4
    a1_gscale = layer.w13_input_scale_quant
    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
        x,
        a1_gscale,
        is_sf_swizzled_layout=False,
    )

    # Determine routing method type
    use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
    routing_method_type = layer.routing_method_type
    if use_llama4_routing:
        routing_method_type = flashinfer.RoutingMethodType.Llama4

    # Prepare routing bias
    routing_bias = e_score_correction_bias
    if routing_bias is not None:
        routing_bias = routing_bias.to(torch.bfloat16)

    router_logits = (
        router_logits.to(torch.float32)
        if routing_method_type == RoutingMethodType.DeepSeekV3
        else router_logits
    )

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
        routing_logits=router_logits,
        routing_bias=routing_bias,
        hidden_states=hidden_states_fp4,
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).flatten(),
        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group if num_expert_group is not None else 0,
        topk_group=topk_group if topk_group is not None else 0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=None,
        routing_method_type=routing_method_type,
        do_finalize=True,
    )[0]

    return out
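Note: with this helper in place, both ModelOptNvFp4FusedMoE.apply and CompressedTensorsW4A4MoeMethod.apply dispatch to the TRT-LLM path through the same call. A condensed sketch of that shared call site, mirroring the two call sites in this diff (assumes the layer was already prepared by process_weights_after_loading):

import torch

from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    flashinfer_trtllm_fp4_moe,
)

def run_trtllm_fp4_moe(
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    num_expert_group: int | None = None,
    topk_group: int | None = None,
    custom_routing_function=None,
    e_score_correction_bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # Same keyword-for-keyword call used by both MoE quant methods in this change.
    return flashinfer_trtllm_fp4_moe(
        layer=layer,
        x=x,
        router_logits=router_logits,
        top_k=top_k,
        global_num_experts=global_num_experts,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        custom_routing_function=custom_routing_function,
        e_score_correction_bias=e_score_correction_bias,
    )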