diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index adce598c4ff1f..9d4e453ffc545 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -891,7 +891,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Block scale must be represented as FP8-E4M3")
 
-        if self.backend == "flashinfer-trtllm":
+        if self.backend == "marlin":
+            prepare_fp4_layer_for_marlin(layer)
+            del layer.alpha
+            del layer.input_scale
+        elif self.backend == "flashinfer-trtllm":
             # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
             # FlashInfer provides nvfp4_quantize to quantize + shuffle the
             # layout but we use our own quantization so we have to call
@@ -916,11 +920,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
                                            requires_grad=False)
             layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
-        if self.backend == "marlin":
-            prepare_fp4_layer_for_marlin(layer)
-            del layer.alpha
-            del layer.input_scale
-
     def apply(
         self,
         layer: torch.nn.Module,
@@ -1312,6 +1311,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             del layer.w2_weight_scale
             del layer.w13_weight
             del layer.w13_weight_scale
+        elif self.use_marlin:
+            # Marlin processing
+            prepare_moe_fp4_layer_for_marlin(layer)
+            del layer.g1_alphas
+            del layer.g2_alphas
+            del layer.w13_input_scale_quant
+            del layer.w2_input_scale_quant
         else:
             # Non-TRT-LLM processing (Cutlass or non-flashinfer)
             assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
@@ -1333,13 +1339,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             layer.w2_weight = Parameter(layer.w2_weight.data,
                                         requires_grad=False)
 
-        if self.use_marlin:
-            prepare_moe_fp4_layer_for_marlin(layer)
-            del layer.g1_alphas
-            del layer.g2_alphas
-            del layer.w13_input_scale_quant
-            del layer.w2_input_scale_quant
-
     def apply(
         self,
         layer: torch.nn.Module,