diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 18c2ab026b2ba..f650a6eabbb9c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -469,16 +469,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
             )
             logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
 
-            layer.gemm1_weights_fp4_shuffled = Parameter(
+            layer.w13_weight = Parameter(
                 gemm1_weights_fp4_shuffled, requires_grad=False
             )
-            layer.gemm2_weights_fp4_shuffled = Parameter(
-                gemm2_weights_fp4_shuffled, requires_grad=False
-            )
-            layer.gemm1_scales_fp4_shuffled = Parameter(
+            layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
+            layer.w13_weight_scale = Parameter(
                 gemm1_scales_fp4_shuffled, requires_grad=False
             )
-            layer.gemm2_scales_fp4_shuffled = Parameter(
+            layer.w2_weight_scale = Parameter(
                 gemm2_scales_fp4_shuffled, requires_grad=False
             )
 
@@ -487,12 +485,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
                 (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                 requires_grad=False,
             )
-
-            # Clean up weights that won't be used by TRT-LLM
-            del layer.w2_weight
-            del layer.w2_weight_scale
-            del layer.w13_weight
-            del layer.w13_weight_scale
         else:
             # swizzle weight scales
             layer.w13_weight_scale = torch.nn.Parameter(
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 030d85080a34d..f71854e6b63c5 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -1458,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             )
             logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
 
-            layer.gemm1_weights_fp4_shuffled = Parameter(
+            layer.w13_weight = Parameter(
                 gemm1_weights_fp4_shuffled, requires_grad=False
             )
-            layer.gemm2_weights_fp4_shuffled = Parameter(
-                gemm2_weights_fp4_shuffled, requires_grad=False
-            )
-            layer.gemm1_scales_fp4_shuffled = Parameter(
+            layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
+            layer.w13_weight_scale = Parameter(
                 gemm1_scales_fp4_shuffled, requires_grad=False
             )
-            layer.gemm2_scales_fp4_shuffled = Parameter(
+            layer.w2_weight_scale = Parameter(
                 gemm2_scales_fp4_shuffled, requires_grad=False
             )
 
@@ -1476,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                 requires_grad=False,
             )
-
-            # Clean up weights that won't be used by TRT-LLM
-            del layer.w2_weight
-            del layer.w2_weight_scale
-            del layer.w13_weight
-            del layer.w13_weight_scale
         elif self.use_marlin:
             # Marlin processing
             prepare_moe_fp4_layer_for_marlin(layer)
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index e424cd0e1ac99..76bce8a8d98d6 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -301,18 +301,14 @@ def flashinfer_trtllm_fp4_moe(
         hidden_states_scale=hidden_states_scale_linear_fp4.view(
             torch.float8_e4m3fn
         ).flatten(),
-        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
-        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
-            torch.float8_e4m3fn
-        ),
+        gemm1_weights=layer.w13_weight.data,
+        gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
         gemm1_bias=None,
         gemm1_alpha=None,
         gemm1_beta=None,
         gemm1_clamp_limit=None,
-        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
-        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
-            torch.float8_e4m3fn
-        ),
+        gemm2_weights=layer.w2_weight.data,
+        gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
         gemm2_bias=None,
         output1_scale_scalar=layer.g1_scale_c.data,
         output1_scale_gate_scalar=layer.g1_alphas.data,
@@ -380,18 +376,14 @@ def flashinfer_trtllm_fp4_routed_moe(
         hidden_states_scale=hidden_states_scale_linear_fp4.view(
             torch.float8_e4m3fn
         ).flatten(),
-        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
-        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
-            torch.float8_e4m3fn
-        ),
+        gemm1_weights=layer.w13_weight.data,
+        gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
         gemm1_bias=None,
         gemm1_alpha=None,
         gemm1_beta=None,
         gemm1_clamp_limit=None,
-        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
-        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
-            torch.float8_e4m3fn
-        ),
+        gemm2_weights=layer.w2_weight.data,
+        gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
         gemm2_bias=None,
         output1_scale_scalar=layer.g1_scale_c.data,
         output1_scale_gate_scalar=layer.g1_alphas.data,
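
Note on the removed "Clean up weights that won't be used by TRT-LLM" block: because the shuffled tensors are now assigned back to the standard FusedMoE parameter names (w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) rather than to separate gemm*_shuffled attributes, the assignment itself replaces the original unshuffled Parameters, making the explicit del statements redundant. A minimal standalone sketch of the torch.nn.Module behavior this relies on (illustrative only, not vLLM code):

    import torch
    from torch.nn import Parameter

    layer = torch.nn.Module()
    layer.w13_weight = Parameter(torch.zeros(4), requires_grad=False)
    old = layer.w13_weight

    # nn.Module.__setattr__ re-registers the name in layer._parameters,
    # replacing the previous Parameter, so the unshuffled tensor loses its
    # last reference without an explicit `del layer.w13_weight`.
    layer.w13_weight = Parameter(torch.ones(4), requires_grad=False)

    assert layer._parameters["w13_weight"] is layer.w13_weight
    assert old is not layer.w13_weight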