Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 02:15:26 +08:00)
Add TRTLLM MoE NVFP4 kernel to CompressedTensorsW4A4MoeMethod (#28892)
Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Parent: e99e467384
Commit: b4c8fbaae2
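Note: the commit wires the FlashInfer TensorRT-LLM NVFP4 MoE kernel into CompressedTensorsW4A4MoeMethod and moves the TRT-LLM weight-shuffling and kernel-call helpers into the shared flashinfer_fp4_moe utilities so ModelOptNvFp4FusedMoE can reuse them. Below is a minimal illustrative sketch (not part of the diff) of the dispatch condition the change adds in several places; the FlashinferMoeBackend enum here is a local stand-in for the real one imported from vllm's flashinfer utils.

from enum import Enum


class FlashinferMoeBackend(Enum):  # stand-in for the enum in vllm's flashinfer utils
    CUTLASS = "CUTLASS"
    TENSORRT_LLM = "TensorRT-LLM"


def use_trtllm_fp4_kernel(allow_flashinfer: bool, backend: FlashinferMoeBackend | None) -> bool:
    # The W4A4 MoE method takes the TRT-LLM path only when FlashInfer is allowed
    # and the selected backend is TENSORRT_LLM; otherwise it keeps the existing
    # CUTLASS / Marlin paths.
    return allow_flashinfer and backend is FlashinferMoeBackend.TENSORRT_LLM


assert use_trtllm_fp4_kernel(True, FlashinferMoeBackend.TENSORRT_LLM)
assert not use_trtllm_fp4_kernel(True, FlashinferMoeBackend.CUTLASS)
assert not use_trtllm_fp4_kernel(False, None)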
Changes to the compressed-tensors MoE quantization module (CompressedTensorsW4A4MoeMethod):

@@ -8,6 +8,7 @@ from enum import Enum
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
+from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -50,9 +51,15 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize,
+    flashinfer_trtllm_fp4_moe,
+    prepare_static_weights_for_trtllm_fp4_moe,
     reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl,
 )
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    FlashinferMoeBackend,
+    get_flashinfer_moe_backend,
+)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     expert_weight_is_col_major,
     requant_weight_ue8m0_inplace,
@@ -193,6 +200,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.group_size = 16
+        self.flashinfer_moe_backend = None
+        if self.allow_flashinfer:
+            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            logger.info_once(
+                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
+                " for CompressedTensorsW4A4MoeMethod."
+            )
 
     def create_weights(
         self,
@@ -344,21 +358,20 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
             return
 
-        # swizzle weight scales
-        layer.w13_weight_scale = torch.nn.Parameter(
-            swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
-        )
-
-        layer.w2_weight_scale = torch.nn.Parameter(
-            swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
-        )
-
         # w13
-        w13_input_global_scale = layer.w13_input_global_scale.max(dim=1).values.to(
-            torch.float32
-        )
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
+            w13_input_global_scale = (
+                layer.w13_input_global_scale.min()
+                .to(torch.float32)
+                .expand(layer.num_experts)
+            )
+        else:
+            w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
+                torch.float32
+            )
         layer.g1_alphas = torch.nn.Parameter(
             ((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
             requires_grad=False,
@@ -369,22 +382,92 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         )
 
         # w2
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
+            w2_input_global_scale = (
+                layer.w2_input_global_scale.min()
+                .to(torch.float32)
+                .expand(layer.num_experts)
+            )
+        else:
+            w2_input_global_scale = layer.w2_input_global_scale
+
         layer.g2_alphas = torch.nn.Parameter(
-            ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
-                torch.float32
-            ),
+            ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )
 
         layer.w2_input_scale_quant = torch.nn.Parameter(
-            (layer.w2_input_global_scale), requires_grad=False
+            (w2_input_global_scale), requires_grad=False
         )
 
+        # TensorRT-LLM specific processing
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
+            # Prepare static weights for TRT-LLM kernel
+            # alternate: prepare_static_weight_layouts_for_trtllm_moe
+            (
+                gemm1_weights_fp4_shuffled,
+                gemm1_scales_fp4_shuffled,
+                gemm2_weights_fp4_shuffled,
+                gemm2_scales_fp4_shuffled,
+            ) = prepare_static_weights_for_trtllm_fp4_moe(
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                layer.w2_weight.size(-2),  # hidden_size
+                layer.w13_weight.size(-2) // 2,  # intermediate_size
+                layer.w13_weight.size(0),  # num_experts
+            )
+            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
+
+            layer.gemm1_weights_fp4_shuffled = Parameter(
+                gemm1_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_weights_fp4_shuffled = Parameter(
+                gemm2_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm1_scales_fp4_shuffled = Parameter(
+                gemm1_scales_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_scales_fp4_shuffled = Parameter(
+                gemm2_scales_fp4_shuffled, requires_grad=False
+            )
+
+            # Additional parameter needed for TRT-LLM
+            layer.g1_scale_c = Parameter(
+                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
+                requires_grad=False,
+            )
+
+            # Clean up weights that won't be used by TRT-LLM
+            del layer.w2_weight
+            del layer.w2_weight_scale
+            del layer.w13_weight
+            del layer.w13_weight_scale
+        else:
+            # swizzle weight scales
+            layer.w13_weight_scale = torch.nn.Parameter(
+                swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
+            )
+
+            layer.w2_weight_scale = torch.nn.Parameter(
+                swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
+            )
 
     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
-        if self.use_marlin:
+        if self.use_marlin or (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
             return None
         elif not self.allow_flashinfer:
             return super().maybe_make_prepare_finalize(routing_tables)
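Note: a small sketch of the scale handling above, with made-up shapes that are not from the diff. On the TRT-LLM path a single global input scale is taken over all experts and broadcast, while the default path keeps a per-expert reduction; g1_alphas is then the reciprocal input scale times the second-level weight scale.

import torch

num_experts = 4
# illustrative per-expert, per-shard input global scales and second-level weight scales
w13_input_global_scale = torch.rand(num_experts, 2) + 0.5
w13_weight_scale_2 = torch.rand(num_experts) + 0.5

# TRT-LLM path: one scalar scale shared by every expert
trtllm_scale = w13_input_global_scale.min().to(torch.float32).expand(num_experts)
# default path: per-expert reduction
default_scale = w13_input_global_scale.min(dim=1).values.to(torch.float32)

g1_alphas_trtllm = (1 / trtllm_scale) * w13_weight_scale_2
g1_alphas_default = (1 / default_scale) * w13_weight_scale_2
print(g1_alphas_trtllm.shape, g1_alphas_default.shape)  # both: torch.Size([4])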
@@ -411,7 +494,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        if self.use_marlin:
+        if (
+            self.use_marlin
+            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
             return None
 
         return nvfp4_moe_quant_config(
@@ -452,6 +538,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         )
         assert activation == "silu", "Only SiLU activation is supported."
 
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
+            return flashinfer_trtllm_fp4_moe(
+                layer=layer,
+                x=x,
+                router_logits=router_logits,
+                top_k=top_k,
+                global_num_experts=global_num_experts,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                custom_routing_function=custom_routing_function,
+                e_score_correction_bias=e_score_correction_bias,
+            )
+
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FusedMoEQuantConfig,
|
FusedMoEQuantConfig,
|
||||||
RoutingMethodType,
|
|
||||||
fp8_w8a8_moe_quant_config,
|
fp8_w8a8_moe_quant_config,
|
||||||
nvfp4_moe_quant_config,
|
nvfp4_moe_quant_config,
|
||||||
)
|
)
|
||||||
@ -38,6 +37,8 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
|
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
|
||||||
|
flashinfer_trtllm_fp4_moe,
|
||||||
|
prepare_static_weights_for_trtllm_fp4_moe,
|
||||||
reorder_w1w3_to_w3w1,
|
reorder_w1w3_to_w3w1,
|
||||||
select_nvfp4_gemm_impl,
|
select_nvfp4_gemm_impl,
|
||||||
)
|
)
|
||||||
@ -1136,7 +1137,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
||||||
self.use_marlin = _nvfp4.use_marlin
|
self.use_marlin = _nvfp4.use_marlin
|
||||||
self.flashinfer_moe_backend = None
|
self.flashinfer_moe_backend = None
|
||||||
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
|
||||||
if self.allow_flashinfer:
|
if self.allow_flashinfer:
|
||||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
@ -1303,138 +1303,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
|
|
||||||
def prepare_static_weights_for_trtllm_fp4_moe(
|
|
||||||
self,
|
|
||||||
# args_dequant,
|
|
||||||
# args,
|
|
||||||
gemm1_weights,
|
|
||||||
gemm2_weights,
|
|
||||||
gemm1_scales_linear_fp4_bytes,
|
|
||||||
gemm2_scales_linear_fp4_bytes,
|
|
||||||
hidden_size,
|
|
||||||
intermediate_size,
|
|
||||||
num_experts,
|
|
||||||
):
|
|
||||||
from flashinfer import nvfp4_block_scale_interleave
|
|
||||||
from flashinfer.fused_moe.core import (
|
|
||||||
_maybe_get_cached_w3_w1_permute_indices,
|
|
||||||
get_w2_permute_indices_with_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
"""Prepare quantized weights for kernel (done offline with weights)."""
|
|
||||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
|
||||||
|
|
||||||
# Convert quantized weights to proper formats
|
|
||||||
gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
|
|
||||||
num_experts, 2 * intermediate_size, hidden_size // 2
|
|
||||||
) # packed fp4
|
|
||||||
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
|
|
||||||
torch.float8_e4m3fn
|
|
||||||
).reshape(
|
|
||||||
num_experts, 2 * intermediate_size, hidden_size // 16
|
|
||||||
) # fp8 scaling factors
|
|
||||||
|
|
||||||
gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
|
|
||||||
num_experts, hidden_size, intermediate_size // 2
|
|
||||||
) # packed fp4
|
|
||||||
gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
|
|
||||||
torch.float8_e4m3fn
|
|
||||||
).reshape(
|
|
||||||
num_experts, hidden_size, intermediate_size // 16
|
|
||||||
) # fp8 scaling factors
|
|
||||||
|
|
||||||
gemm1_weights_fp4_shuffled = []
|
|
||||||
gemm1_scales_fp4_shuffled = []
|
|
||||||
gemm2_weights_fp4_shuffled = []
|
|
||||||
gemm2_scales_fp4_shuffled = []
|
|
||||||
for i in range(num_experts):
|
|
||||||
# Calculate the permute indices for the following:
|
|
||||||
# 1. Reorder rows of W1 and scales for fused gated activation
|
|
||||||
# 2. Shuffle weights and scaling factors for transposed mma output
|
|
||||||
# for both w3_w1 and w2 weights and scale factors
|
|
||||||
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
|
|
||||||
self._cache_permute_indices,
|
|
||||||
gemm1_weights_fp4[i].view(torch.uint8),
|
|
||||||
epilogue_tile_m,
|
|
||||||
)
|
|
||||||
gemm1_weights_fp4_shuffled.append(
|
|
||||||
gemm1_weights_fp4[i]
|
|
||||||
.view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
|
|
||||||
permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
|
|
||||||
self._cache_permute_indices,
|
|
||||||
gemm1_scales_linear_fp4[i].view(torch.uint8),
|
|
||||||
epilogue_tile_m,
|
|
||||||
num_elts_per_sf=16,
|
|
||||||
)
|
|
||||||
gemm1_scales_fp4_shuffled.append(
|
|
||||||
nvfp4_block_scale_interleave(
|
|
||||||
gemm1_scales_linear_fp4[i]
|
|
||||||
.view(torch.uint8)[
|
|
||||||
permute_sf_indices.to(gemm1_scales_linear_fp4.device)
|
|
||||||
]
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
permute_indices = get_w2_permute_indices_with_cache(
|
|
||||||
self._cache_permute_indices,
|
|
||||||
gemm2_weights_fp4[i].view(torch.uint8),
|
|
||||||
epilogue_tile_m,
|
|
||||||
)
|
|
||||||
gemm2_weights_fp4_shuffled.append(
|
|
||||||
gemm2_weights_fp4[i]
|
|
||||||
.view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
|
|
||||||
permute_sf_indices = get_w2_permute_indices_with_cache(
|
|
||||||
self._cache_permute_indices,
|
|
||||||
gemm2_scales_linear_fp4[i].view(torch.uint8),
|
|
||||||
epilogue_tile_m,
|
|
||||||
num_elts_per_sf=16,
|
|
||||||
)
|
|
||||||
gemm2_scales_fp4_shuffled.append(
|
|
||||||
nvfp4_block_scale_interleave(
|
|
||||||
gemm2_scales_linear_fp4[i]
|
|
||||||
.view(torch.uint8)[
|
|
||||||
permute_sf_indices.to(gemm2_scales_linear_fp4.device)
|
|
||||||
]
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stack weights for all experts
|
|
||||||
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
|
|
||||||
gemm1_scales_fp4_shuffled = (
|
|
||||||
torch.stack(gemm1_scales_fp4_shuffled)
|
|
||||||
.view(torch.float8_e4m3fn)
|
|
||||||
.reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
|
|
||||||
)
|
|
||||||
|
|
||||||
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
|
|
||||||
gemm2_scales_fp4_shuffled = (
|
|
||||||
torch.stack(gemm2_scales_fp4_shuffled)
|
|
||||||
.view(torch.float8_e4m3fn)
|
|
||||||
.reshape(num_experts, hidden_size, intermediate_size // 16)
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
gemm1_weights_fp4_shuffled,
|
|
||||||
gemm1_scales_fp4_shuffled,
|
|
||||||
gemm2_weights_fp4_shuffled,
|
|
||||||
gemm2_scales_fp4_shuffled,
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
# GEMM 1 processing
|
# GEMM 1 processing
|
||||||
gemm1_weight = layer.w13_weight.data
|
gemm1_weight = layer.w13_weight.data
|
||||||
gemm1_weight_scale = layer.w13_weight_scale.data
|
gemm1_weight_scale = layer.w13_weight_scale.data
|
||||||
|
|
||||||
if (
|
if self.allow_flashinfer and (
|
||||||
self.allow_flashinfer
|
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
||||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||||
):
|
):
|
||||||
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
|
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
|
||||||
gemm1_weight, gemm1_weight_scale, dim=-2
|
gemm1_weight, gemm1_weight_scale, dim=-2
|
||||||
@ -1508,7 +1384,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
gemm1_scales_fp4_shuffled,
|
gemm1_scales_fp4_shuffled,
|
||||||
gemm2_weights_fp4_shuffled,
|
gemm2_weights_fp4_shuffled,
|
||||||
gemm2_scales_fp4_shuffled,
|
gemm2_scales_fp4_shuffled,
|
||||||
) = self.prepare_static_weights_for_trtllm_fp4_moe(
|
) = prepare_static_weights_for_trtllm_fp4_moe(
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
layer.w13_weight_scale,
|
layer.w13_weight_scale,
|
||||||
@ -1614,68 +1490,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
self.allow_flashinfer
|
self.allow_flashinfer
|
||||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||||
):
|
):
|
||||||
import flashinfer
|
return flashinfer_trtllm_fp4_moe(
|
||||||
|
layer=layer,
|
||||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
x=x,
|
||||||
|
router_logits=router_logits,
|
||||||
a1_gscale = layer.w13_input_scale_quant
|
|
||||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = (
|
|
||||||
flashinfer.fp4_quantize(
|
|
||||||
x,
|
|
||||||
a1_gscale,
|
|
||||||
is_sf_swizzled_layout=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
use_llama4_routing = (
|
|
||||||
custom_routing_function is Llama4MoE.custom_routing_function
|
|
||||||
)
|
|
||||||
routing_method_type = layer.routing_method_type
|
|
||||||
if use_llama4_routing:
|
|
||||||
routing_method_type = RoutingMethodType.Llama4
|
|
||||||
router_logits = (
|
|
||||||
router_logits.to(torch.float32)
|
|
||||||
if routing_method_type == RoutingMethodType.DeepSeekV3
|
|
||||||
else router_logits
|
|
||||||
)
|
|
||||||
routing_bias = e_score_correction_bias
|
|
||||||
if routing_bias is not None:
|
|
||||||
routing_bias = routing_bias.to(torch.bfloat16)
|
|
||||||
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
|
|
||||||
routing_logits=router_logits,
|
|
||||||
routing_bias=routing_bias,
|
|
||||||
hidden_states=hidden_states_fp4,
|
|
||||||
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
|
||||||
torch.float8_e4m3fn
|
|
||||||
).flatten(),
|
|
||||||
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
|
|
||||||
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
|
|
||||||
torch.float8_e4m3fn
|
|
||||||
),
|
|
||||||
gemm1_bias=None,
|
|
||||||
gemm1_alpha=None,
|
|
||||||
gemm1_beta=None,
|
|
||||||
gemm1_clamp_limit=None,
|
|
||||||
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
|
|
||||||
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
|
|
||||||
torch.float8_e4m3fn
|
|
||||||
),
|
|
||||||
gemm2_bias=None,
|
|
||||||
output1_scale_scalar=layer.g1_scale_c.data,
|
|
||||||
output1_scale_gate_scalar=layer.g1_alphas.data,
|
|
||||||
output2_scale_scalar=layer.g2_alphas.data,
|
|
||||||
num_experts=global_num_experts,
|
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
n_group=num_expert_group,
|
global_num_experts=global_num_experts,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
intermediate_size=layer.intermediate_size_per_partition,
|
custom_routing_function=custom_routing_function,
|
||||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
local_num_experts=layer.local_num_experts,
|
)
|
||||||
routed_scaling_factor=1.0,
|
|
||||||
tile_tokens_dim=None,
|
|
||||||
routing_method_type=routing_method_type,
|
|
||||||
do_finalize=True,
|
|
||||||
)[0]
|
|
||||||
return out
|
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
|
|||||||
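Note: the hunks above delete the class-local copy of the weight-shuffling helper and the inline kernel call from ModelOptNvFp4FusedMoE and route both through the shared utilities added below. The helper relies on flashinfer's permute-index caching; a minimal, hypothetical stand-in for that pattern (not flashinfer's actual API) is:

import torch


def get_permute_indices_cached(
    cache: dict[torch.Size, torch.Tensor], weight: torch.Tensor
) -> torch.Tensor:
    # Identical weight shapes reuse one permutation instead of recomputing it.
    if weight.shape not in cache:
        cache[weight.shape] = torch.randperm(weight.shape[0])  # placeholder permutation
    return cache[weight.shape]


cache: dict[torch.Size, torch.Tensor] = {}
w = torch.zeros(128, 64, dtype=torch.uint8)
idx_first = get_permute_indices_cached(cache, w)
idx_second = get_permute_indices_cached(cache, torch.zeros_like(w))
assert torch.equal(idx_first, idx_second)  # second lookup hits the cache

One design difference visible in the diff: the removed method cached indices on the layer instance (self._cache_permute_indices), while the moved module-level function below creates a fresh local dict on each call, so the cache only deduplicates work across the experts of a single invocation.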
Changes to the FlashInfer FP4 MoE utilities (vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe):

@@ -9,6 +9,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
+    RoutingMethodType,
 )
 from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
     FlashInferCuteDSLExperts,
@@ -110,3 +111,223 @@ def select_nvfp4_gemm_impl(
         "CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
         "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)"
     )
+
+
+def prepare_static_weights_for_trtllm_fp4_moe(
+    # args_dequant,
+    # args,
+    gemm1_weights,
+    gemm2_weights,
+    gemm1_scales_linear_fp4_bytes,
+    gemm2_scales_linear_fp4_bytes,
+    hidden_size,
+    intermediate_size,
+    num_experts,
+):
+    from flashinfer import nvfp4_block_scale_interleave
+    from flashinfer.fused_moe.core import (
+        _maybe_get_cached_w3_w1_permute_indices,
+        get_w2_permute_indices_with_cache,
+    )
+
+    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
+    """Prepare quantized weights for kernel (done offline with weights)."""
+    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+
+    # Convert quantized weights to proper formats
+    gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+        num_experts, 2 * intermediate_size, hidden_size // 2
+    )  # packed fp4
+    gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+        torch.float8_e4m3fn
+    ).reshape(
+        num_experts, 2 * intermediate_size, hidden_size // 16
+    )  # fp8 scaling factors
+
+    gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+        num_experts, hidden_size, intermediate_size // 2
+    )  # packed fp4
+    gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+        torch.float8_e4m3fn
+    ).reshape(num_experts, hidden_size, intermediate_size // 16)  # fp8 scaling factors
+
+    gemm1_weights_fp4_shuffled = []
+    gemm1_scales_fp4_shuffled = []
+    gemm2_weights_fp4_shuffled = []
+    gemm2_scales_fp4_shuffled = []
+    for i in range(num_experts):
+        # Calculate the permute indices for the following:
+        # 1. Reorder rows of W1 and scales for fused gated activation
+        # 2. Shuffle weights and scaling factors for transposed mma output
+        # for both w3_w1 and w2 weights and scale factors
+        permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+            _cache_permute_indices,
+            gemm1_weights_fp4[i].view(torch.uint8),
+            epilogue_tile_m,
+        )
+        gemm1_weights_fp4_shuffled.append(
+            gemm1_weights_fp4[i]
+            .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+            .contiguous()
+        )
+
+        permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+            _cache_permute_indices,
+            gemm1_scales_linear_fp4[i].view(torch.uint8),
+            epilogue_tile_m,
+            num_elts_per_sf=16,
+        )
+        gemm1_scales_fp4_shuffled.append(
+            nvfp4_block_scale_interleave(
+                gemm1_scales_linear_fp4[i]
+                .view(torch.uint8)[
+                    permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+                ]
+                .contiguous()
+            )
+        )
+
+        permute_indices = get_w2_permute_indices_with_cache(
+            _cache_permute_indices,
+            gemm2_weights_fp4[i].view(torch.uint8),
+            epilogue_tile_m,
+        )
+        gemm2_weights_fp4_shuffled.append(
+            gemm2_weights_fp4[i]
+            .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+            .contiguous()
+        )
+
+        permute_sf_indices = get_w2_permute_indices_with_cache(
+            _cache_permute_indices,
+            gemm2_scales_linear_fp4[i].view(torch.uint8),
+            epilogue_tile_m,
+            num_elts_per_sf=16,
+        )
+        gemm2_scales_fp4_shuffled.append(
+            nvfp4_block_scale_interleave(
+                gemm2_scales_linear_fp4[i]
+                .view(torch.uint8)[
+                    permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+                ]
+                .contiguous()
+            )
+        )
+
+    # Stack weights for all experts
+    gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+    gemm1_scales_fp4_shuffled = (
+        torch.stack(gemm1_scales_fp4_shuffled)
+        .view(torch.float8_e4m3fn)
+        .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+    )
+
+    gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+    gemm2_scales_fp4_shuffled = (
+        torch.stack(gemm2_scales_fp4_shuffled)
+        .view(torch.float8_e4m3fn)
+        .reshape(num_experts, hidden_size, intermediate_size // 16)
+    )
+    return (
+        gemm1_weights_fp4_shuffled,
+        gemm1_scales_fp4_shuffled,
+        gemm2_weights_fp4_shuffled,
+        gemm2_scales_fp4_shuffled,
+    )
+
+
+def flashinfer_trtllm_fp4_moe(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    global_num_experts: int,
+    num_expert_group: int | None,
+    topk_group: int | None,
+    custom_routing_function: object | None,
+    e_score_correction_bias: torch.Tensor | None,
+) -> torch.Tensor:
+    """
+    Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
+
+    Args:
+        layer: The MoE layer with weights and scales
+        x: Input tensor
+        router_logits: Router logits for expert selection
+        top_k: Number of experts to select per token
+        global_num_experts: Total number of experts across all ranks
+        num_expert_group: Number of expert groups (for grouped routing)
+        topk_group: Top-k within each group
+        custom_routing_function: Custom routing function (e.g., Llama4)
+        e_score_correction_bias: Optional routing bias correction
+
+    Returns:
+        Output tensor from the MoE layer
+    """
+    import flashinfer
+
+    from vllm.model_executor.models.llama4 import Llama4MoE
+
+    # Quantize input to FP4
+    a1_gscale = layer.w13_input_scale_quant
+    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
+        x,
+        a1_gscale,
+        is_sf_swizzled_layout=False,
+    )
+
+    # Determine routing method type
+    use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
+    routing_method_type = layer.routing_method_type
+    if use_llama4_routing:
+        routing_method_type = flashinfer.RoutingMethodType.Llama4
+
+    # Prepare routing bias
+    routing_bias = e_score_correction_bias
+    if routing_bias is not None:
+        routing_bias = routing_bias.to(torch.bfloat16)
+
+    router_logits = (
+        router_logits.to(torch.float32)
+        if routing_method_type == RoutingMethodType.DeepSeekV3
+        else router_logits
+    )
+
+    # Call TRT-LLM FP4 block-scale MoE kernel
+    out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
+        routing_logits=router_logits,
+        routing_bias=routing_bias,
+        hidden_states=hidden_states_fp4,
+        hidden_states_scale=hidden_states_scale_linear_fp4.view(
+            torch.float8_e4m3fn
+        ).flatten(),
+        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
+        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
+            torch.float8_e4m3fn
+        ),
+        gemm1_bias=None,
+        gemm1_alpha=None,
+        gemm1_beta=None,
+        gemm1_clamp_limit=None,
+        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
+        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
+            torch.float8_e4m3fn
+        ),
+        gemm2_bias=None,
+        output1_scale_scalar=layer.g1_scale_c.data,
+        output1_scale_gate_scalar=layer.g1_alphas.data,
+        output2_scale_scalar=layer.g2_alphas.data,
+        num_experts=global_num_experts,
+        top_k=top_k,
+        n_group=num_expert_group if num_expert_group is not None else 0,
+        topk_group=topk_group if topk_group is not None else 0,
+        intermediate_size=layer.intermediate_size_per_partition,
+        local_expert_offset=layer.ep_rank * layer.local_num_experts,
+        local_num_experts=layer.local_num_experts,
+        routed_scaling_factor=None,
+        tile_tokens_dim=None,
+        routing_method_type=routing_method_type,
+        do_finalize=True,
+    )[0]
+
+    return out
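Note: for reference, a shape sketch with illustrative sizes, following the reshape calls in prepare_static_weights_for_trtllm_fp4_moe above. Packed NVFP4 weights hold two 4-bit values per byte, and there is one FP8 block scale per 16 elements along the reduction dimension.

num_experts, hidden_size, intermediate_size = 8, 1024, 2048

w13_packed = (num_experts, 2 * intermediate_size, hidden_size // 2)    # packed fp4 weights
w13_scales = (num_experts, 2 * intermediate_size, hidden_size // 16)   # fp8 block scales
w2_packed = (num_experts, hidden_size, intermediate_size // 2)
w2_scales = (num_experts, hidden_size, intermediate_size // 16)
print(w13_packed, w13_scales, w2_packed, w2_scales)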