diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index f2242ade0c0f1..31ea826f1f97a 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -324,6 +324,8 @@ class FusedMoEConfig: max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE + has_bias: bool = False + def __post_init__(self): if self.dp_size > 1: logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d", @@ -413,7 +415,8 @@ class FusedMoEConfig: in_dtype: torch.dtype, max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE, quant_config: Optional[Union[FusedMoEQuantConfig, - QuantizationConfig]] = None + QuantizationConfig]] = None, + has_bias: bool = False, ) -> "FusedMoEConfig": _quant_config: Optional[FusedMoEQuantConfig] = None @@ -482,4 +485,5 @@ class FusedMoEConfig: in_dtype=in_dtype, quant_config=_quant_config, max_num_tokens=max_num_tokens, + has_bias=has_bias, ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f4f5457ebcd03..3ad5f5b7ad31d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -275,6 +275,7 @@ def fused_moe_kernel( a_ptr, b_ptr, c_ptr, + b_bias_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr, @@ -302,6 +303,8 @@ def fused_moe_kernel( stride_bse, stride_bsk, stride_bsn, + stride_bbe, # bias expert stride + stride_bbn, # bias N stride # Block size for block-wise quantization group_n: tl.constexpr, group_k: tl.constexpr, @@ -317,6 +320,7 @@ def fused_moe_kernel( use_int8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, per_channel_quant: tl.constexpr, + HAS_BIAS: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -414,7 +418,10 @@ def fused_moe_kernel( else: a_scale = tl.load(a_scale_ptr) b_scale = tl.load(b_scale_ptr + off_experts) - + if HAS_BIAS: + # bias shape: [num_experts, N] + bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn + bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block @@ -456,7 +463,8 @@ def fused_moe_kernel( # Advance the ptrs to the next K block. 
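# NOTE (illustrative sketch, not part of the patch): reference semantics of the new
# HAS_BIAS path above. b_bias has shape [num_experts, N]; each program picks the row
# for its expert (off_experts) and broadcasts it over the M dimension of the output
# tile, i.e. a grouped GEMM with a per-expert bias:
import torch

def grouped_gemm_with_bias(a, b, b_bias, expert_ids):
    # a: [num_tokens, K], b: [num_experts, N, K], b_bias: [num_experts, N]
    out = torch.empty(a.size(0), b.size(1), dtype=a.dtype, device=a.device)
    for e in range(b.size(0)):
        rows = expert_ids == e
        out[rows] = a[rows] @ b[e].t() + b_bias[e]  # bias row broadcast over tokens
    return out

out = grouped_gemm_with_bias(torch.randn(6, 4), torch.randn(3, 5, 4),
                             torch.randn(3, 5), torch.randint(0, 3, (6,)))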
a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - + if HAS_BIAS: + accumulator = accumulator + bias[None, :] if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, @@ -471,6 +479,7 @@ def fused_moe_kernel( accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -499,7 +508,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_int8_w8a16: bool, use_int4_w4a16: bool, per_channel_quant: bool, - block_shape: Optional[list[int]] = None) -> None: + block_shape: Optional[list[int]] = None, + B_bias: Optional[torch.Tensor] = None) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -531,7 +541,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, A.size(0) * top_k * config['BLOCK_SIZE_M']) grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.size(1), META['BLOCK_SIZE_N']), ) - + HAS_BIAS = B_bias is not None if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 @@ -611,6 +621,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, A, B, C, + B_bias, A_scale, B_scale, topk_weights, @@ -638,6 +649,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, if B_scale is not None and B_scale.ndim == 3 else 0, B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_bias.stride(0) if B_bias is not None else 0, + B_bias.stride(1) if B_bias is not None else 0, 0 if block_shape is None else block_shape[0], 0 if block_shape is None else block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, @@ -647,6 +660,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, per_channel_quant=per_channel_quant, + HAS_BIAS=HAS_BIAS, BLOCK_SIZE_K=BLOCK_SIZE_K, **config, ) @@ -1024,40 +1038,43 @@ def inplace_fused_experts( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> None: #noqa: UP006 + block_shape: Optional[List[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> None: #noqa: UP006 fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, - a2_scale, block_shape) + a2_scale, block_shape, w1_bias, w2_bias) -def inplace_fused_experts_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: 
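# NOTE (illustrative sketch, not part of the patch): the bias is added to the
# accumulator *before* the MUL_ROUTED_WEIGHT scaling above, so a routed expert
# output is (x @ W.T + b) * w_route — the routing weight scales the whole expert
# output, bias included, matching the unfused reference. The two orderings differ
# whenever w_route != 1:
import torch

x, W, b, w_route = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16), 0.3
kernel_order = (x @ W.t() + b) * w_route   # what the HAS_BIAS path computes
swapped = (x @ W.t()) * w_route + b
assert not torch.allclose(kernel_order, swapped)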
Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> None: +def inplace_fused_experts_fake(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + is_act_and_mul: bool = True, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> None: pass @@ -1246,36 +1263,38 @@ direct_register_custom_op( def outplace_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, #noqa: UP006 + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + is_act_and_mul: bool = True, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, #noqa: UP006 + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, w1, w2, topk_weights, topk_ids, False, activation, is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, - w1_zp, w2_zp, a1_scale, a2_scale, block_shape) + w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias) def outplace_fused_experts_fake( @@ -1300,7 +1319,9 @@ def outplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> 
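# NOTE (illustrative sketch, not part of the patch; uses torch.library directly
# rather than vLLM's registration helper): each fused_experts op is paired with a
# *_fake twin so torch.compile can trace output shapes/dtypes without launching the
# Triton kernel — the out-of-place fake just returns an empty tensor shaped like the
# input, mirroring outplace_fused_experts_fake below.
import torch
from torch.library import custom_op

@custom_op("demo::outplace_moe", mutates_args=())
def outplace_moe(hidden_states: torch.Tensor) -> torch.Tensor:
    return hidden_states * 2.0  # stand-in for the real fused MoE computation

@outplace_moe.register_fake
def _(hidden_states: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(hidden_states)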
torch.Tensor: + block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1332,33 +1353,34 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: # TODO (bnell): replace this with modular op. Can get rid of inplace/outplace # torch ops. -def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False, - allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor: +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + is_act_and_mul: bool = True, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + allow_deep_gemm: bool = False, + allow_cutlass_block_scaled_grouped_gemm: bool = False, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. # However, on B200, we use DeepGemm for all cases because they only support @@ -1423,7 +1445,10 @@ def fused_experts( w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, - block_shape=block_shape) + block_shape=block_shape, + w1_bias=w1_bias, + w2_bias=w2_bias, + ) def fused_experts_impl( @@ -1451,6 +1476,8 @@ def fused_experts_impl( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: # Check constraints. 
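# NOTE (illustrative usage sketch, not part of the patch; toy sizes, needs a CUDA
# device): calling the extended fused_experts entry point with per-expert biases.
# Shapes follow the parameters registered in layer.py below:
#   w1 [E, 2I, H], w1_bias [E, 2I], w2 [E, H, I], w2_bias [E, H].
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

E, H, I, T, topk = 8, 256, 512, 16, 2
x = torch.randn(T, H, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, 2 * I, H, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(E, H, I, dtype=torch.bfloat16, device="cuda")
w1_bias = torch.zeros(E, 2 * I, dtype=torch.bfloat16, device="cuda")
w2_bias = torch.zeros(E, H, dtype=torch.bfloat16, device="cuda")
topk_weights = torch.softmax(torch.randn(T, topk, device="cuda"), dim=-1)
topk_ids = torch.randint(0, E, (T, topk), device="cuda", dtype=torch.int32)

out = fused_experts(x, w1, w2, topk_weights, topk_ids,
                    activation="swiglu_oai", w1_bias=w1_bias, w2_bias=w2_bias)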
if use_int4_w4a16: @@ -1591,7 +1618,19 @@ def fused_experts_impl( use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, - block_shape=block_shape) + block_shape=block_shape, + B_bias=w1_bias) + + # TODO fused kernel + def swiglu_oai(gate_up): + alpha = 1.702 + limit = 7.0 + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + gated_output = (up + 1) * glu + return gated_output # Activation function with multiplication if activation == "silu" and is_act_and_mul: @@ -1605,6 +1644,8 @@ def fused_experts_impl( intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) elif activation == "gelu": intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) + elif activation == "swiglu_oai": + intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N)) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}, " f"with is_act_and_mul={is_act_and_mul}.") @@ -1635,7 +1676,8 @@ def fused_experts_impl( use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, - block_shape=block_shape) + block_shape=block_shape, + B_bias=w2_bias) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx]) @@ -1672,6 +1714,8 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -1766,7 +1810,9 @@ def fused_moe( w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, - block_shape=block_shape) + block_shape=block_shape, + w1_bias=w1_bias, + w2_bias=w2_bias) class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -1937,7 +1983,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape) + block_shape=self.block_shape, + B_bias=None # TODO support B_bias + ) self.activation(activation, intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -1948,26 +1996,29 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): intermediate_cache2, a2_scale, self.quant_dtype, self.per_act_token_quant, self.block_shape) - invoke_fused_moe_kernel(qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - not apply_router_weight_on_input, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape) + invoke_fused_moe_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_act_token_quant, + block_shape=self.block_shape, + B_bias=None # TODO support B_bias + ) 
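# NOTE (illustrative sketch, not part of the patch): swiglu_oai above reads gate/up
# *interleaved* along the last dimension (gate = [..., ::2], up = [..., 1::2]) and
# halves the feature width, unlike silu_and_mul, which expects the two halves
# concatenated. alpha=1.702 and the +/-7.0 clamp follow the values hard-coded above.
import torch

def swiglu_oai_ref(gate_up, alpha=1.702, limit=7.0):
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return (up + 1) * gate * torch.sigmoid(gate * alpha)

y = swiglu_oai_ref(torch.randn(4, 8))  # [..., 2 * intermediate] -> [..., intermediate]
assert y.shape == (4, 4)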
ops.moe_sum(intermediate_cache3, output) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d664a92841bbe..d5a89655e36d6 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -255,7 +255,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): self.fused_experts = fused_experts # type: ignore self.topk_indices_dtype = None self.moe = moe - + self.has_bias = self.moe.has_bias self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts @@ -291,7 +291,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - + if self.has_bias: + w13_bias = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) # down_proj (row parallel) w2_weight = torch.nn.Parameter(torch.empty( num_experts, @@ -301,6 +308,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + if self.has_bias: + w2_bias = torch.nn.Parameter(torch.zeros(num_experts, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. This is an optimization on ROCm platform, which @@ -465,6 +479,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, + w1_bias=layer.w13_bias if self.has_bias else None, + w2_bias=layer.w2_bias if self.has_bias else None, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, @@ -702,6 +718,7 @@ class FusedMoE(torch.nn.Module): activation: str = "silu", enable_eplb: bool = False, num_redundant_experts: int = 0, + has_bias: bool = False, ): super().__init__() if params_dtype is None: @@ -793,16 +810,15 @@ class FusedMoE(torch.nn.Module): # since model_config is not set in the pytest test. 
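# NOTE (illustrative sketch, not part of the patch): how the parameters registered
# above map onto the kernel-side bias arguments (intermediate_size_per_partition is
# the TP shard of the intermediate size):
#   layer.w13_bias [E, 2 * intermediate_size_per_partition] -> w1_bias (first GEMM)
#   layer.w2_bias  [E, hidden_size]                         -> w2_bias (second GEMM)
# A minimal stand-in for create_weights with has_bias=True, using toy sizes:
import torch

num_experts, hidden_size, intermediate_size_per_partition = 4, 64, 128
w13_bias = torch.nn.Parameter(
    torch.zeros(num_experts, 2 * intermediate_size_per_partition),
    requires_grad=False)
w2_bias = torch.nn.Parameter(
    torch.zeros(num_experts, hidden_size), requires_grad=False)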
model_dtype = params_dtype - moe = FusedMoEConfig.make( - num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - in_dtype=model_dtype, - max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, - quant_config=quant_config, - ) + moe = FusedMoEConfig.make(num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=model_dtype, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, + quant_config=quant_config, + has_bias=has_bias) self.moe_config = moe self.quant_config = quant_config diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index feb323a04524b..6a65bbbe2e0db 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -160,7 +160,9 @@ class MLPBlock(torch.nn.Module): renormalize=True, quant_config=quant_config, prefix=f"{prefix}.experts", - apply_router_weight_on_input=False) + apply_router_weight_on_input=False, + has_bias=True, + activation="swiglu_oai") def forward(self, x: torch.Tensor) -> torch.Tensor: t = self.norm(x) @@ -262,8 +264,8 @@ class GptOssForCausalLM(nn.Module): sampling_metadata) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def _load_weights_mxfp4( + self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: rename_mapping = { "self_attn": "attn", "input_layernorm.weight": "attn.norm.weight", @@ -469,3 +471,147 @@ class GptOssForCausalLM(nn.Module): loaded_params.add(renamed_name) return loaded_params + + def _load_weights_other( + self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + rename_mapping = { + "self_attn": "attn", + "input_layernorm.weight": "attn.norm.weight", + "post_attention_layernorm.weight": "mlp.norm.weight", + "embed_tokens": "embedding", + } + + def maybe_rename(name: str) -> str: + for remap_name, new_name in rename_mapping.items(): + if remap_name in name: + return name.replace(remap_name, new_name) + return name + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + intermediate_size = self.model_config.intermediate_size + + per_rank_intermediate_size = cdiv(intermediate_size, tp_size) + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + # Attention heads per rank + heads_per_rank = self.model_config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + use_ep = self.vllm_config.parallel_config.enable_expert_parallel + ep_size = get_ep_group().world_size + ep_rank = get_ep_group().rank + num_experts = self.model_config.num_local_experts + experts_per_rank = num_experts // ep_size + ep_rank_start = ep_rank * experts_per_rank + ep_rank_end = (ep_rank + 1) * experts_per_rank + + for name, weight in weights: + if ".experts.gate_up_proj" in name and "bias" not in name: + # Handle MLP gate and up projection weights + new_name = name.replace(".experts.gate_up_proj", + ".experts.w13_weight") + + # Extract gate and up projection parts + # since the weight is shuffled, we can slice directly + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
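# NOTE (illustrative sketch, not part of the patch): why the TP branch that follows
# can slice the interleaved gate_up_proj weight as
# weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end] — gate and up columns alternate,
# so taking columns [2*start, 2*end) for the rank's intermediate slice [start, end)
# keeps every (gate, up) pair together:
import torch

intermediate_size, tp_size, tp_rank = 10, 4, 1
per_rank = -(-intermediate_size // tp_size)  # cdiv
start = tp_rank * per_rank
end = min((tp_rank + 1) * per_rank, intermediate_size)

gate_up = torch.arange(2 * intermediate_size)      # columns [g0, u0, g1, u1, ...]
shard = gate_up[2 * start:2 * end]
assert shard[::2].tolist() == [2 * i for i in range(start, end)]       # gates
assert shard[1::2].tolist() == [2 * i + 1 for i in range(start, end)]  # ups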
+ else: + narrow_weight = weight[:, :, + 2 * tp_rank_start:2 * tp_rank_end] + + narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + param = params_dict[new_name] + + param.copy_(narrow_weight) + loaded_params.add(new_name) + + elif ".experts.down_proj" in name and "bias" not in name: + # Handle MLP down projection weights + new_name = name.replace(".experts.down_proj", + ".experts.w2_weight") + + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] + narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + param = params_dict[new_name] + + param.copy_(narrow_weight) + loaded_params.add(new_name) + + elif "gate_up_proj_bias" in name: + # Handle MLP gate and up projection biases + new_name = name.replace("gate_up_proj_bias", "w13_bias") + + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[new_name] + + param.copy_(narrow_weight) + loaded_params.add(new_name) + + elif "down_proj_bias" in name: + # Handle MLP down projection bias + new_name = name.replace("down_proj_bias", "w2_bias") + + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] + else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + param = params_dict[new_name] + param.copy_(weight) + loaded_params.add(new_name) + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + name = name.replace("self_attn", "attn") + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + elif "q_proj" in name or "k_proj" in name or "v_proj" in name: + shard_id = ("q" if "q_proj" in name else + "k" if "k_proj" in name else "v") + name = name.replace("self_attn", "attn") + param_name = name.replace(f"{shard_id}_proj", "qkv") + param = params_dict[param_name] + weight_loader = param.weight_loader + weight_loader(param, weight, loaded_shard_id=shard_id) + loaded_params.add(param_name) + else: + # Handle all other weights with potential renaming + + renamed_name = maybe_rename(name) + if renamed_name not in params_dict: + continue + param = params_dict[renamed_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(renamed_name) + + return loaded_params + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + quant_method = (self.model_config.quantization_config['quant_method'] + if hasattr(self.model_config, "quantization_config") + else None) + if quant_method == "mxfp4": + return self._load_weights_mxfp4(weights) + else: + return self._load_weights_other(weights)
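# NOTE (illustrative sketch, not part of the patch): why the down_proj bias is kept
# only on TP rank 0 above. The down projection is row parallel, so each rank produces
# a partial output that is summed (all-reduced) across ranks; zeroing the bias on
# ranks != 0 ensures the reduced sum contains the bias exactly once.
import torch

tp_size, hidden = 2, 4
x_shards = [torch.randn(3, 5) for _ in range(tp_size)]       # per-rank activations
w_shards = [torch.randn(hidden, 5) for _ in range(tp_size)]  # per-rank down_proj shards
bias = torch.randn(hidden)

partials = [x_shards[r] @ w_shards[r].t() +
            (bias if r == 0 else torch.zeros_like(bias))     # bias on rank 0 only
            for r in range(tp_size)]
reduced = sum(partials)                                      # stands in for all-reduce
reference = sum(x_shards[r] @ w_shards[r].t() for r in range(tp_size)) + bias
assert torch.allclose(reduced, reference, atol=1e-5)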