diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index b6d7bc5d5cccd..068af027398ba 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -109,55 +109,74 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.intermediate_size = intermediate_size_per_partition_after_pad self.hidden_size = hidden_size # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition_after_pad, - hidden_size // 2, - dtype=weight_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + hidden_size // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w13_weight_scale = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition_after_pad, - hidden_size // mxfp4_block, - dtype=scale_dtype), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + hidden_size // mxfp4_block, + dtype=scale_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) set_weight_attrs(w13_weight_scale, extra_weight_attrs) - w13_bias = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition_after_pad, - dtype=torch.bfloat16), - requires_grad=False) + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) layer.register_parameter("w13_bias", w13_bias) set_weight_attrs(w13_bias, extra_weight_attrs) # down_proj (row parallel) - w2_weight = torch.nn.Parameter(torch.zeros( - num_experts, - hidden_size, - intermediate_size_per_partition_after_pad // 2, - dtype=weight_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - w2_weight_scale = torch.nn.Parameter(torch.zeros( - num_experts, - hidden_size, - intermediate_size_per_partition_after_pad // mxfp4_block, - dtype=scale_dtype), - requires_grad=False) + w2_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad // mxfp4_block, + dtype=scale_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) set_weight_attrs(w2_weight_scale, extra_weight_attrs) - w2_bias = torch.nn.Parameter(torch.zeros(num_experts, - hidden_size, - dtype=torch.bfloat16), - requires_grad=False) + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) layer.register_parameter("w2_bias", w2_bias) set_weight_attrs(w2_bias, extra_weight_attrs)