mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 11:56:11 +08:00
[Misc] Enhance code formatting in mxfp4.py (#22423)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
c2dba2dba8
commit
136825de75
@ -109,55 +109,74 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
# Fused gate_up_proj (column parallel)
|
# Fused gate_up_proj (column parallel)
|
||||||
w13_weight = torch.nn.Parameter(torch.zeros(
|
w13_weight = torch.nn.Parameter(
|
||||||
num_experts,
|
torch.zeros(
|
||||||
2 * intermediate_size_per_partition_after_pad,
|
num_experts,
|
||||||
hidden_size // 2,
|
2 * intermediate_size_per_partition_after_pad,
|
||||||
dtype=weight_dtype),
|
hidden_size // 2,
|
||||||
requires_grad=False)
|
dtype=weight_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
layer.register_parameter("w13_weight", w13_weight)
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
w13_weight_scale = torch.nn.Parameter(torch.zeros(
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
num_experts,
|
torch.zeros(
|
||||||
2 * intermediate_size_per_partition_after_pad,
|
num_experts,
|
||||||
hidden_size // mxfp4_block,
|
2 * intermediate_size_per_partition_after_pad,
|
||||||
dtype=scale_dtype),
|
hidden_size // mxfp4_block,
|
||||||
requires_grad=False)
|
dtype=scale_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
w13_bias = torch.nn.Parameter(torch.zeros(
|
w13_bias = torch.nn.Parameter(
|
||||||
num_experts,
|
torch.zeros(
|
||||||
2 * intermediate_size_per_partition_after_pad,
|
num_experts,
|
||||||
dtype=torch.bfloat16),
|
2 * intermediate_size_per_partition_after_pad,
|
||||||
requires_grad=False)
|
dtype=torch.bfloat16,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
layer.register_parameter("w13_bias", w13_bias)
|
layer.register_parameter("w13_bias", w13_bias)
|
||||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||||
|
|
||||||
# down_proj (row parallel)
|
# down_proj (row parallel)
|
||||||
w2_weight = torch.nn.Parameter(torch.zeros(
|
w2_weight = torch.nn.Parameter(
|
||||||
num_experts,
|
torch.zeros(
|
||||||
hidden_size,
|
num_experts,
|
||||||
intermediate_size_per_partition_after_pad // 2,
|
hidden_size,
|
||||||
dtype=weight_dtype),
|
intermediate_size_per_partition_after_pad // 2,
|
||||||
requires_grad=False)
|
dtype=weight_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
layer.register_parameter("w2_weight", w2_weight)
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
w2_weight_scale = torch.nn.Parameter(torch.zeros(
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
num_experts,
|
torch.zeros(
|
||||||
hidden_size,
|
num_experts,
|
||||||
intermediate_size_per_partition_after_pad // mxfp4_block,
|
hidden_size,
|
||||||
dtype=scale_dtype),
|
intermediate_size_per_partition_after_pad // mxfp4_block,
|
||||||
requires_grad=False)
|
dtype=scale_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
|
w2_bias = torch.nn.Parameter(
|
||||||
hidden_size,
|
torch.zeros(
|
||||||
dtype=torch.bfloat16),
|
num_experts,
|
||||||
requires_grad=False)
|
hidden_size,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
layer.register_parameter("w2_bias", w2_bias)
|
layer.register_parameter("w2_bias", w2_bias)
|
||||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user