[oss] Init gpt-oss bf16 support (#22508)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-08-10 11:19:13 +08:00 committed by GitHub
parent 61f67d8acd
commit 0c5254b82a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 340 additions and 123 deletions

View File

@ -324,6 +324,8 @@ class FusedMoEConfig:
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
def __post_init__(self):
if self.dp_size > 1:
logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d",
@ -413,7 +415,8 @@ class FusedMoEConfig:
in_dtype: torch.dtype,
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config: Optional[Union[FusedMoEQuantConfig,
QuantizationConfig]] = None
QuantizationConfig]] = None,
has_bias: bool = False,
) -> "FusedMoEConfig":
_quant_config: Optional[FusedMoEQuantConfig] = None
@ -482,4 +485,5 @@ class FusedMoEConfig:
in_dtype=in_dtype,
quant_config=_quant_config,
max_num_tokens=max_num_tokens,
has_bias=has_bias,
)

View File

@ -275,6 +275,7 @@ def fused_moe_kernel(
a_ptr,
b_ptr,
c_ptr,
b_bias_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
@ -302,6 +303,8 @@ def fused_moe_kernel(
stride_bse,
stride_bsk,
stride_bsn,
stride_bbe, # bias expert stride
stride_bbn, # bias N stride
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
@ -317,6 +320,7 @@ def fused_moe_kernel(
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
@ -414,7 +418,10 @@ def fused_moe_kernel(
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
if HAS_BIAS:
# bias shape: [num_experts, N]
bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
@ -456,7 +463,8 @@ def fused_moe_kernel(
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if HAS_BIAS:
accumulator = accumulator + bias[None, :]
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
@ -471,6 +479,7 @@ def fused_moe_kernel(
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@ -499,7 +508,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None) -> None:
block_shape: Optional[list[int]] = None,
B_bias: Optional[torch.Tensor] = None) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@ -531,7 +541,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
A.size(0) * top_k * config['BLOCK_SIZE_M'])
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.size(1), META['BLOCK_SIZE_N']), )
HAS_BIAS = B_bias is not None
if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0:
assert B_scale is not None and B_scale.ndim == 3
@ -611,6 +621,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
A,
B,
C,
B_bias,
A_scale,
B_scale,
topk_weights,
@ -638,6 +649,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1)
if B_scale is not None and B_scale.ndim >= 2 else 0,
B_bias.stride(0) if B_bias is not None else 0,
B_bias.stride(1) if B_bias is not None else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
@ -647,6 +660,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
HAS_BIAS=HAS_BIAS,
BLOCK_SIZE_K=BLOCK_SIZE_K,
**config,
)
@ -1024,40 +1038,43 @@ def inplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None: #noqa: UP006
block_shape: Optional[List[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> None: #noqa: UP006
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, is_act_and_mul,
apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
use_mxfp4_w4a4, per_channel_quant, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape)
a2_scale, block_shape, w1_bias, w2_bias)
def inplace_fused_experts_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None) -> None:
def inplace_fused_experts_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> None:
pass
@ -1246,36 +1263,38 @@ direct_register_custom_op(
def outplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, #noqa: UP006
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, #noqa: UP006
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
def outplace_fused_experts_fake(
@ -1300,7 +1319,9 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None) -> torch.Tensor:
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
@ -1332,33 +1353,34 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# However, on B200, we use DeepGemm for all cases because they only support
@ -1423,7 +1445,10 @@ def fused_experts(
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
block_shape=block_shape,
w1_bias=w1_bias,
w2_bias=w2_bias,
)
def fused_experts_impl(
@ -1451,6 +1476,8 @@ def fused_experts_impl(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
@ -1591,7 +1618,19 @@ def fused_experts_impl(
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
block_shape=block_shape,
B_bias=w1_bias)
# TODO fused kernel
def swiglu_oai(gate_up):
alpha = 1.702
limit = 7.0
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
gated_output = (up + 1) * glu
return gated_output
# Activation function with multiplication
if activation == "silu" and is_act_and_mul:
@ -1605,6 +1644,8 @@ def fused_experts_impl(
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == "gelu":
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
elif activation == "swiglu_oai":
intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
f"with is_act_and_mul={is_act_and_mul}.")
@ -1635,7 +1676,8 @@ def fused_experts_impl(
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
block_shape=block_shape,
B_bias=w2_bias)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
@ -1672,6 +1714,8 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@ -1766,7 +1810,9 @@ def fused_moe(
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
block_shape=block_shape,
w1_bias=w1_bias,
w2_bias=w2_bias)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@ -1937,7 +1983,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape)
block_shape=self.block_shape,
B_bias=None # TODO support B_bias
)
self.activation(activation, intermediate_cache2,
intermediate_cache1.view(-1, N))
@ -1948,26 +1996,29 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache2, a2_scale, self.quant_dtype,
self.per_act_token_quant, self.block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape)
invoke_fused_moe_kernel(
qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=None # TODO support B_bias
)
ops.moe_sum(intermediate_cache3, output)

View File

@ -255,7 +255,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.fused_experts = fused_experts # type: ignore
self.topk_indices_dtype = None
self.moe = moe
self.has_bias = self.moe.has_bias
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
@ -291,7 +291,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.has_bias:
w13_bias = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
@ -301,6 +308,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.has_bias:
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
# Pad the weight tensor. This is an optimization on ROCm platform, which
@ -465,6 +479,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_bias=layer.w13_bias if self.has_bias else None,
w2_bias=layer.w2_bias if self.has_bias else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
@ -702,6 +718,7 @@ class FusedMoE(torch.nn.Module):
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
):
super().__init__()
if params_dtype is None:
@ -793,16 +810,15 @@ class FusedMoE(torch.nn.Module):
# since model_config is not set in the pytest test.
model_dtype = params_dtype
moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config=quant_config,
)
moe = FusedMoEConfig.make(num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config=quant_config,
has_bias=has_bias)
self.moe_config = moe
self.quant_config = quant_config

View File

@ -160,7 +160,9 @@ class MLPBlock(torch.nn.Module):
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
apply_router_weight_on_input=False)
apply_router_weight_on_input=False,
has_bias=True,
activation="swiglu_oai")
def forward(self, x: torch.Tensor) -> torch.Tensor:
t = self.norm(x)
@ -262,8 +264,8 @@ class GptOssForCausalLM(nn.Module):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def _load_weights_mxfp4(
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rename_mapping = {
"self_attn": "attn",
"input_layernorm.weight": "attn.norm.weight",
@ -469,3 +471,147 @@ class GptOssForCausalLM(nn.Module):
loaded_params.add(renamed_name)
return loaded_params
def _load_weights_other(
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rename_mapping = {
"self_attn": "attn",
"input_layernorm.weight": "attn.norm.weight",
"post_attention_layernorm.weight": "mlp.norm.weight",
"embed_tokens": "embedding",
}
def maybe_rename(name: str) -> str:
for remap_name, new_name in rename_mapping.items():
if remap_name in name:
return name.replace(remap_name, new_name)
return name
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.model_config.intermediate_size
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
# Attention heads per rank
heads_per_rank = self.model_config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
ep_size = get_ep_group().world_size
ep_rank = get_ep_group().rank
num_experts = self.model_config.num_local_experts
experts_per_rank = num_experts // ep_size
ep_rank_start = ep_rank * experts_per_rank
ep_rank_end = (ep_rank + 1) * experts_per_rank
for name, weight in weights:
if ".experts.gate_up_proj" in name and "bias" not in name:
# Handle MLP gate and up projection weights
new_name = name.replace(".experts.gate_up_proj",
".experts.w13_weight")
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, :,
2 * tp_rank_start:2 * tp_rank_end]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif ".experts.down_proj" in name and "bias" not in name:
# Handle MLP down projection weights
new_name = name.replace(".experts.down_proj",
".experts.w2_weight")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif "gate_up_proj_bias" in name:
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_bias")
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif "down_proj_bias" in name:
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_bias")
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
param = params_dict[new_name]
param.copy_(weight)
loaded_params.add(new_name)
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
shard_id = ("q" if "q_proj" in name else
"k" if "k_proj" in name else "v")
name = name.replace("self_attn", "attn")
param_name = name.replace(f"{shard_id}_proj", "qkv")
param = params_dict[param_name]
weight_loader = param.weight_loader
weight_loader(param, weight, loaded_shard_id=shard_id)
loaded_params.add(param_name)
else:
# Handle all other weights with potential renaming
renamed_name = maybe_rename(name)
if renamed_name not in params_dict:
continue
param = params_dict[renamed_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(renamed_name)
return loaded_params
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
quant_method = (self.model_config.quantization_config['quant_method']
if hasattr(self.model_config, "quantization_config")
else None)
if quant_method == "mxfp4":
return self._load_weights_mxfp4(weights)
else:
return self._load_weights_other(weights)