From 7b4424b404a4e8b9397881ac45df38dab632a90f Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Mon, 22 Dec 2025 18:31:58 +0000 Subject: [PATCH] Merge: ROCM:upstream_fp8_with_static_scales_gpt_oss Signed-off-by: Gregory Shtrasberg --- vllm/_aiter_ops.py | 6 + vllm/envs.py | 13 + .../model_executor/layers/fused_moe/config.py | 26 + .../fused_moe/gpt_oss_triton_kernels_moe.py | 175 ++++- vllm/model_executor/layers/fused_moe/layer.py | 74 ++- .../layers/quantization/quark/quark.py | 51 ++ .../layers/quantization/quark/quark_moe.py | 617 +++++++++++++++++- vllm/model_executor/models/gpt_oss.py | 300 +++++++++ 8 files changed, 1246 insertions(+), 16 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 59cfd8627cc5c..5ee8aabcda607 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -655,6 +655,7 @@ class rocm_aiter_ops: envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD ) _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE + _CK_MXFP4_MOE = envs.VLLM_ROCM_USE_CK_MXFP4_MOE _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA @@ -701,6 +702,11 @@ class rocm_aiter_ops: """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._FMOE_ENABLED + @classmethod + @if_aiter_supported + def is_mxfp4_aiter_moe(cls) -> bool: + return cls._AITER_ENABLED and cls._CK_MXFP4_MOE + @classmethod @if_aiter_supported def is_fusion_moe_shared_experts_enabled(cls) -> bool: diff --git a/vllm/envs.py b/vllm/envs.py index 23b53f5377e13..467f78a06fe82 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -114,6 +114,8 @@ if TYPE_CHECKING: VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True VLLM_ROCM_USE_AITER_MOE: bool = True + VLLM_ROCM_USE_CK_MXFP4_MOE: bool = True + VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD: bool = 
True VLLM_ROCM_USE_AITER_MLA: bool = True @@ -953,6 +955,17 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ROCM_USE_AITER_MOE": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in ("true", "1") ), + # Whether to use aiter ck mxfp4 moe ops. + # By default is enabled. + "VLLM_ROCM_USE_CK_MXFP4_MOE": lambda: ( + os.getenv("VLLM_ROCM_USE_CK_MXFP4_MOE", "True").lower() in ("true", "1") + ), + # Whether to use aiter w4a16 moe ops. + # By default is disabled. + "VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4", "False").lower() + in ("true", "1") + ), # use aiter rms norm op if aiter ops are enabled. "VLLM_ROCM_USE_AITER_RMSNORM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1") diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index a9a2990ca2b53..c7d2a4c1cea89 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -711,6 +711,32 @@ def int8_w8a16_moe_quant_config( ) +def mxfp4_w4a4_moe_quant_config( + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + block_shape: list[int] | None = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for mxfp4 activations and mxfp4 weights. 
+ """ + return FusedMoEQuantConfig.make( + "mxfp4", + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + per_act_token_quant=False, + per_out_ch_quant=False, + block_shape=block_shape, + ) + + def int4_w4afp8_moe_quant_config( w1_scale: torch.Tensor, w2_scale: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 0b006e15632e1..4ebf0053ed8d5 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -196,6 +196,178 @@ def triton_kernel_fused_experts( return output_tensor +def triton_kernel_moe_oss_forward( + hidden_states: torch.Tensor, + w1, # Tensor or triton_kernels.Tensor + w2, # Tensor or triton_kernels.Tensor + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + activation: str = "silu", + quant_config: FusedMoEQuantConfig | None = None, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + unpadded_N_w1=None, + unpadded_K_w1=None, + unpadded_N_w2=None, + unpadded_K_w2=None, +) -> torch.Tensor: + assert quant_config is not None + + if quant_config.use_mxfp4_w4a16: + routing_data, gather_idx, scatter_idx = routing( + gating_output, topk, sm_first=not renormalize + ) + elif quant_config.use_mxfp4_w4a4: + from aiter.ops.triton.moe_routing.routing import routing as aiter_routing + + routing_data, gather_idx, scatter_idx = aiter_routing( + gating_output, topk, sm_first=not renormalize + ) + + return triton_kernel_fused_oss_experts( + None, + hidden_states, + w1, + w2, + routing_data, + gather_idx, + scatter_idx, + activation=activation, + quant_config=quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + 
unpadded_N_w1=unpadded_N_w1, + unpadded_K_w1=unpadded_K_w1, + unpadded_N_w2=unpadded_N_w2, + unpadded_K_w2=unpadded_K_w2, + ) + + +# This is a triton implementation of the fused_experts function +def triton_kernel_fused_oss_experts( + output_tensor: torch.Tensor, + hidden_states: torch.Tensor, + w1, # Tensor or triton_kernels.Tensor + w2, # Tensor or triton_kernels.Tensor + routing_data, # RoutingData + gather_indx, # GatherIndx + scatter_indx, # ScatterIndx + activation: str = "silu", + quant_config: FusedMoEQuantConfig | None = None, + swiglu_alpha: float = 1.702, + swiglu_limit: float = 7.0, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + a1q_scale: torch.Tensor | None = None, + unpadded_N_w1=None, + unpadded_K_w1=None, + unpadded_N_w2=None, + unpadded_K_w2=None, +) -> torch.Tensor: + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + + # type check, uint8 means mxfp4 + assert hidden_states.dtype == torch.bfloat16 + assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32 + assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 + + # Shape check, only check non-mxfp4 + assert hidden_states.shape[-1] == w1.shape[-2] + assert w2.shape[-1] == w1.shape[1] + + E, _, N = w1.shape + + if global_num_experts == -1: + global_num_experts = E + + gammas = routing_data.gate_scal if routing_data else None + + if quant_config.use_mxfp4_w4a16: + act = FusedActivation( + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), + (swiglu_alpha, swiglu_limit), + 2, + ) + intermediate_cache1 = matmul_ogs( + hidden_states, + w1, + quant_config.w1_bias, + routing_data, + gather_indx=gather_indx, + precision_config=quant_config.w1_precision, + gammas=gammas if apply_router_weight_on_input else None, + fused_activation=act, + ) + intermediate_cache3 = matmul_ogs( + intermediate_cache1, + w2, + quant_config.w2_bias, + 
routing_data, + scatter_indx=scatter_indx, + precision_config=quant_config.w2_precision, + gammas=None if apply_router_weight_on_input else gammas, + y=output_tensor, + ) + + elif quant_config.use_mxfp4_w4a4: + from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4 + from aiter.ops.triton.quant_moe import downcast_to_static_fp8 + + assert quant_config.w1_precision is not None, ( + "w1_precision in quant config can't be None" + ) + assert quant_config.w2_precision is not None, ( + "w2_precision in quant config can't be None" + ) + + hidden_states = downcast_to_static_fp8( + hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale + ) + + intermediate_cache1 = moe_gemm_a8w4( + hidden_states, + w1.storage.data, + None, + quant_config.w1_precision.weight_scale.storage.data, + quant_config.w1_precision.flex_ctx.lhs_data.scale, + quant_config.w2_precision.flex_ctx.lhs_data.scale, + quant_config.w1_bias, + routing_data, + gather_indx=gather_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale="CDNA4_SCALE", + out_dtype=torch.float8_e4m3fn, + apply_swiglu=True, + alpha=swiglu_alpha, + limit=swiglu_limit, + unpadded_N=unpadded_N_w1, + unpadded_K=unpadded_K_w1, + ) + + intermediate_cache3 = moe_gemm_a8w4( + intermediate_cache1, + w2.storage.data, + None, + quant_config.w2_precision.weight_scale.storage.data, + quant_config.w2_precision.flex_ctx.lhs_data.scale, + None, + quant_config.w2_bias, + routing_data, + scatter_indx=scatter_indx, + gammas=None if apply_router_weight_on_input else gammas, + swizzle_mx_scale="CDNA4_SCALE", + unpadded_N=unpadded_N_w2, + unpadded_K=unpadded_K_w2, + ) + + return intermediate_cache3 + + def make_routing_data( topk_ids: torch.Tensor, topk_weights: torch.Tensor, @@ -443,9 +615,6 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): - if self.quant_config is None: - self.quant_config = 
FUSED_MOE_UNQUANTIZED_CONFIG - if expert_map is not None: topk_ids = expert_map[topk_ids] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index cc3afade709d9..16204a93b1a1d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -275,7 +275,9 @@ def maybe_roundup_hidden_size( ) # we are padding globally so EP buffer allocation works - if quant_config and quant_config.get_name() == "mxfp4": + if quant_config and ( + quant_config.get_name() == "mxfp4" or quant_config.get_name() == "quark" + ): from vllm.model_executor.layers.quantization.mxfp4 import ( Mxfp4Backend, get_mxfp4_backend, @@ -414,7 +416,10 @@ class FusedMoE(CustomOp): # Expert mapping used in self.load_weights self.expert_mapping = expert_mapping + self._is_mxfp4 = self.is_mxfp4_quant(quant_config=quant_config) + # Round up hidden size if needed. + unpadded_hidden_size = hidden_size hidden_size = maybe_roundup_hidden_size( hidden_size, moe_in_dtype, @@ -635,6 +640,7 @@ class FusedMoE(CustomOp): moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, + "unpadded_hidden_size": unpadded_hidden_size, "intermediate_size_per_partition": self.intermediate_size_per_partition, "params_dtype": params_dtype, "weight_loader": self.weight_loader, @@ -1115,16 +1121,42 @@ class FusedMoE(CustomOp): expert_id: int, return_success: bool = False, ) -> bool | None: - if self.quant_config and self.quant_config.get_name() == "mxfp4": - # (FIXME) for gpt-oss all experts are combined - if "bias" in weight_name: - dim1 = loaded_weight.shape[1] - param.data[:, :dim1].copy_(loaded_weight) - else: - dim1 = loaded_weight.shape[1] - dim2 = loaded_weight.shape[2] - param.data[:, :dim1, :dim2].copy_(loaded_weight) - return True if return_success else None + if self._is_mxfp4: + assert self.quant_config is not None + if self.quant_config.get_name() == "mxfp4": + # (FIXME) for gpt-oss all 
experts are combined + if "bias" in weight_name: + dim1 = loaded_weight.shape[1] + param.data[:, :dim1].copy_(loaded_weight) + else: + dim1 = loaded_weight.shape[1] + dim2 = loaded_weight.shape[2] + param.data[:, :dim1, :dim2].copy_(loaded_weight) + return True if return_success else None + elif self.quant_config.get_name() == "quark": + # When self._is_mxfp4 is true, model_dtype must be gpt_oss + expert_data = param.data[expert_id] + if "input_scale" in weight_name: + assert loaded_weight.numel() == 1 + expert_data.data.copy_(loaded_weight) + return True if return_success else None + + shard_dim = ( + 0 if shard_id in ("w1", "w3") or "bias" in weight_name else 1 + ) + if shard_id == "w2": + shard_size = loaded_weight.shape[shard_dim] // self.tp_size + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * self.tp_rank, shard_size + ) + if "bias" in weight_name: + dim1 = loaded_weight.shape[0] + expert_data.data[:dim1].copy_(loaded_weight) + else: + dim1 = loaded_weight.shape[0] + dim2 = loaded_weight.shape[1] + expert_data.data[:dim1, :dim2].copy_(loaded_weight) + return True if return_success else None quant_method_name = self.quant_method.__class__.__name__ global_expert_id = expert_id @@ -2070,6 +2102,26 @@ class FusedMoE(CustomOp): return s + def is_mxfp4_quant(self, quant_config: QuantizationConfig | None = None) -> bool: + if quant_config is None: + return False + + name = quant_config.get_name() + if name == "mxfp4": + return True + elif name == "quark": + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + model_type = getattr(vllm_config.model_config.hf_config, "model_type", None) + from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig + + assert isinstance(quant_config, QuarkConfig) + # Padding for triton kernel only is enabled when it is gpt_oss + return quant_config.is_global_mxfp4 and model_type == "gpt_oss" + + return False + def moe_forward( hidden_states: torch.Tensor, diff 
--git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 3640e5c452786..386880f1719a8 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -59,6 +59,19 @@ class QuarkConfig(QuantizationConfig): self.kv_cache_group = kv_cache_group self.kv_cache_config = kv_cache_config self.pack_method = pack_method + self._is_global_mxfp4() + + def _is_global_mxfp4(self): + # Check if it is MXFP4 to determine if pre-padding should be applied. + # This must be created during the initialization of moe. + global_quant_config = cast( + dict[str, Any], self.quant_config.get("global_quant_config") + ) + weight_quant = global_quant_config.get("weight") + input_quant = global_quant_config.get("input_tensors") + self.is_global_mxfp4 = self._is_mx_fp4( + weight_quant=weight_quant, input_quant=input_quant + ) def get_linear_method(self) -> "QuarkLinearMethod": return QuarkLinearMethod(self) @@ -277,6 +290,44 @@ class QuarkConfig(QuantizationConfig): # Only symmetric weight quantization supported. return is_int8_dtype and is_tensor and is_weight_symmetric and is_static + def _is_mx_fp4( + self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None + ) -> bool: + # Confirm weights quantized. + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + return False + + # Input and weight dtype needs to be fp4. + if weight_quant.get("dtype") != "fp4": + logger.debug("Quark model is not in MX-FP4 format: weight dtype not fp4") + return False + + # Input and weight qscheme needs to be per group. + if weight_quant.get("qscheme") != "per_group": + logger.debug("Quark model is not in MX-FP4 format: not per_group") + return False + + # Input and weight group size needs to be 32. 
+ if weight_quant.get("group_size") != 32: + logger.debug("Quark model is not in MX-FP4 format: not group_size=32") + return False + + # Activations and weight scales need to be in e8m0 format. + if weight_quant.get("scale_format") != "e8m0": + logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0") + return False + + # Input dtype needs to be one of {'fp4', 'fp6_e2m3', 'fp8_e4m3'}. + if input_quant.get("dtype") not in ("fp4", "fp6_e2m3", "fp8_e4m3"): + logger.debug( + "Quark model is not in MX-FP4 format: expected input dtype " + "to be one of {'fp4', 'fp6_e2m3', 'fp8_e4m3'}" + ) + return False + + return True + def _is_ocp_mx( self, weight_quant: dict[str, Any] | None, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index d84e22d1fa0f2..cf7e97c145637 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -5,8 +5,8 @@ from typing import Any import torch -import vllm.envs as envs from vllm import _custom_ops as ops +from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -18,12 +18,14 @@ from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, + mxfp4_w4a4_moe_quant_config, ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin, ) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_BLOCK_SIZE, OCP_MX_Scheme, @@ -37,10 +39,17 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from 
vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.utils.math_utils import round_up logger = init_logger(__name__) -__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"] +__all__ = [ + "QuarkMoEMethod", + "QuarkW8A8Fp8MoEMethod", + "QuarkW4MXFp4MoEMethod_OSS", + "QuarkW4MXFp4MoEMethod", + "QuarkOCP_MX_MoEMethod", +] class QuarkMoEMethod(FusedMoEMethodBase): @@ -64,8 +73,16 @@ class QuarkMoEMethod(FusedMoEMethodBase): weight_config = layer_quant_config.get("weight") input_config = layer_quant_config.get("input_tensors") + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + model_type = getattr(vllm_config.model_config.hf_config, "model_type", None) + if quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) + elif quant_config._is_mx_fp4(weight_config, input_config) and model_type == "gpt_oss": + return QuarkW4MXFp4MoEMethod_OSS( + weight_config, input_config, module.moe_config + ) elif quant_config._is_ocp_mx(weight_config, input_config): return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config) else: @@ -620,3 +637,599 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ) return out + + +class QuarkW4MXFp4MoEMethodBase(QuarkMoEMethod): + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + ): + super().__init__(moe) + self.weight_quant = weight_config + self.input_quant = input_config + self.weight_qscheme = self.weight_quant.get("qscheme") + self.input_qscheme = self.input_quant.get("qscheme") + self.static_input_scales = not self.input_quant.get("is_dynamic") + + def create_common_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + weight_scale: torch.dtype, + weight_scale_dtype: 
torch.dtype, + weight_scale_block_size: int, + **extra_weight_attrs, + ): + # WEIGHTS + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=weight_scale, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // weight_scale_block_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=weight_scale, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // weight_scale_block_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + +class QuarkW4MXFp4MoEMethod(QuarkW4MXFp4MoEMethodBase): + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + ): + super().__init__(weight_config, input_config, moe) + if not ( + self.weight_qscheme == "per_group" and self.input_qscheme == "per_group" + ): + raise ValueError( + "For MX(FP4) Fused MoE layers, only per-group scales " + "for weights and activations are supported. 
Found " + f"{self.weight_qscheme}, {self.input_qscheme}" + ) # noqa E501 + + if self.static_input_scales: + raise NotImplementedError( + "QuarkW4MXFp4MoEMethod with static input scales is currently " + "not implemented. Please open an issue." + ) + + self.emulate = not current_platform.supports_mx() or not ( + rocm_aiter_ops.is_mxfp4_aiter_moe() + ) + + if self.emulate: + logger.warning_once( + f"The current mode (supports_mx={current_platform.supports_mx()}, " + f"use_mxfp4_aiter_moe={rocm_aiter_ops.is_mxfp4_aiter_moe()}, " + "does not support native MXFP4/MXFP6 " + "computation. Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision." + ) + else: + logger.info_once("The current mode supports native MoE MXFP4 computation") + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) + self.create_common_weights( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + torch.uint8, + torch.uint8, + OCP_MX_BLOCK_SIZE, + **extra_weight_attrs, + ) + + def process_weights_after_loading(self, layer): + if self.emulate: + return + + from aiter.utility.fp4_utils import e8m0_shuffle + + # Pre-shuffle weight scales + s0, s1, _ = layer.w13_weight_scale.shape + w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1) + w13_weight_scale = e8m0_shuffle(w13_weight_scale) + layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1) + + s0, s1, _ = layer.w2_weight_scale.shape + w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1) + w2_weight_scale = e8m0_shuffle(w2_weight_scale) + layer.w2_weight_scale.data = 
w2_weight_scale.view(s0, s1, -1) + torch.cuda.empty_cache() + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return mxfp4_w4a4_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=layer.w1_weight_scale, + w2_scale=layer.w2_weight_scale, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + expert_map: torch.Tensor | None = None, + ) -> torch.Tensor: + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW4MXFp4MoEMethod` yet." + ) + + topk_weights, topk_ids = layer.select_experts( + hidden_states=x, + router_logits=router_logits, + ) + + if not self.emulate: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + + aiter_acts = { + ActivationType.No.name.lower(): ActivationType.No, + ActivationType.Silu.name.lower(): ActivationType.Silu, + ActivationType.Gelu.name.lower(): ActivationType.Gelu, + } + assert layer.activation in aiter_acts, ( + f"Aiter CK fp4 MoE doesn't support activation {layer.activation}" + ) + if hasattr(torch, "float4_e2m1fn_x2"): + w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2) + w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2) + else: + w13_weight = layer.w13_weight + w2_weight = layer.w2_weight + + out = fused_moe( + x, + w13_weight, + w2_weight, + topk_weights, + topk_ids, + quant_type=QuantType.per_1x32, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + activation=aiter_acts[layer.activation], + doweight_stage1=False, + ) + else: + from vllm.model_executor.layers.fused_moe import fused_experts + + out = fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + expert_map=expert_map, + 
quant_config=self.moe_quant_config, + ) + return out + + +class QuarkW4MXFp4MoEMethod_OSS(QuarkW4MXFp4MoEMethodBase): + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + ): + super().__init__(weight_config, input_config, moe) + + if not (self.weight_qscheme == "per_group"): + raise ValueError( + "For MX(FP4) Fused MoE layers, only per-group scales " + "for weights and activations are supported. Found " + f"{self.weight_qscheme}, {self.input_qscheme}" + ) # noqa E501 + + self.static_input_scales = not self.input_quant.get("is_dynamic") + self.emulate = not current_platform.supports_mx() + if self.emulate: + logger.warning_once( + "The current platform does not support native MXFP4 " + "computation. Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision." + ) + else: + logger.warning_once( + "The current platform supports native MXFP4 " + "computation, but kernels are not yet integrated in vLLM. " + "Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision." 
+ ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + self.num_experts = num_experts + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) + mxfp4_block = 32 + weight_dtype = torch.uint8 + weight_scale_dtype = torch.uint8 + per_tensor_fp8_act_scale_dtype = torch.bfloat16 + self.intermediate_size_per_partition = intermediate_size_per_partition + intermediate_size_per_partition_after_pad = intermediate_size_per_partition + + if current_platform.is_rocm(): + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 256 + ) # 2880 -> 2944 + else: + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 64 + ) + + self.unpadded_hidden_size = extra_weight_attrs.get( + "unpadded_hidden_size", hidden_size + ) + self.hidden_pad = hidden_size - self.unpadded_hidden_size + self.intermediate_pad = ( + intermediate_size_per_partition_after_pad - intermediate_size_per_partition + ) + + self.create_common_weights( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad, + weight_dtype, + weight_scale_dtype, + mxfp4_block, + **extra_weight_attrs, + ) + + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + if self.static_input_scales: + 
w13_input_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + dtype=per_tensor_fp8_act_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + dtype=per_tensor_fp8_act_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + def process_weights_after_loading(self, layer): + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + w13_bias = layer.w13_bias.to(torch.float32) + w2_bias = layer.w2_bias.to(torch.float32) + + layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False) + layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False) + + # FIXME warp need to be adjusted based on batch size + # only apply to batched mode + if self.moe.use_ep: + num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 + else: + num_warps = 8 + + if envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: + from aiter.ops.shuffle import shuffle_scale_a16w4, shuffle_weight_a16w4 + + w13_aiter_weight = layer.w13_weight.contiguous() + w13_aiter_scale = layer.w13_weight_scale.contiguous() + w2_aiter_weight = layer.w2_weight.contiguous() + w2_aiter_scale = layer.w2_weight_scale.contiguous() + + e, n, k = w13_aiter_weight.shape + w13_aiter_weight = ( + w13_aiter_weight.view(e, n // 2, 2, k) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, k) + ) + w13_aiter_scale = ( + w13_aiter_scale.view(e, n // 2, 2, -1) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, -1) + ) + + w13_aiter_weight = w13_aiter_weight.view(torch.float4_e2m1fn_x2) + w13_aiter_scale = w13_aiter_scale.view(-1, w13_aiter_scale.shape[-1]) + w2_aiter_weight = w2_aiter_weight.view(torch.float4_e2m1fn_x2) + w2_aiter_scale = w2_aiter_scale.view(-1, w2_aiter_scale.shape[-1]) + + self.w13_weight_aiter_tensor = 
shuffle_weight_a16w4( + w13_aiter_weight, 16, True + ) + self.w13_scale_aiter_tensor = shuffle_scale_a16w4( + w13_aiter_scale, self.num_experts, True + ) + self.w2_weight_aiter_tensor = shuffle_weight_a16w4( + w2_aiter_weight, 16, False + ) + self.w2_scale_aiter_tensor = shuffle_scale_a16w4( + w2_aiter_scale, self.num_experts, False + ) + self.w13_bias_aiter_tensor = ( + layer.w13_bias.view(-1, n // 2, 2) + .permute(0, 2, 1) + .contiguous() + .view(-1, n) + ) + else: + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( + layer.w13_weight, layer.w13_weight_scale, num_warps + ) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( + layer.w2_weight, layer.w2_weight_scale, num_warps + ) + + self.w13_weight_triton_tensor = w13_weight + self.w2_weight_triton_tensor = w2_weight + + # need to delete the original weights to save memory on single GPU + del layer.w13_weight + del layer.w2_weight + layer.w13_weight = None + layer.w2_weight = None + torch.cuda.empty_cache() + + if not envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer." 
+ ) + # layer.w13_input_scale = torch.nn.Parameter( + # layer.w13_input_scale.max(), requires_grad=False) + # layer.w2_input_scale = torch.nn.Parameter( + # layer.w2_input_scale.max(), requires_grad=False) + + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max().to(torch.float32), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max().to(torch.float32), requires_grad=False + ) + + from triton_kernels.numerics import InFlexData + + lhs_data13 = InFlexData(scale=layer.w13_input_scale) + lhs_data2 = InFlexData(scale=layer.w2_input_scale) + + self.w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, + flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13), + ) + self.w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, + flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2), + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + if envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: + return mxfp4_w4a4_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=self.w13_scale_aiter_tensor, + w2_scale=self.w2_scale_aiter_tensor, + ) + else: + w1_scale = self.w13_precision_config + w2_scale = self.w2_precision_config + + # TODO: how to set scale? + return mxfp4_w4a4_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + expert_map: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet." 
+ ) + + if envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: + from aiter import moe_cktile2stages_gemm1, moe_cktile2stages_gemm2 + from aiter.fused_moe import fused_topk, moe_sorting + + token_num = x.shape[0] + BLOCKM = 16 if token_num < 2048 else 32 + topk_weights, topk_ids = fused_topk(x, router_logits, layer.top_k, True) + sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_out = ( + moe_sorting( + topk_ids, + topk_weights, + self.num_experts, + x.shape[1], + torch.bfloat16, + BLOCKM, + ) + ) + _, n1, k1 = self.w13_weight_aiter_tensor.shape + _, k2, n2 = self.w2_weight_aiter_tensor.shape + D = n2 if k2 == k1 else n2 * 2 + cktile_moe_out1 = torch.empty( + (token_num, layer.top_k, D), dtype=torch.bfloat16, device=x.device + ) + moe_cktile2stages_gemm1( + x, + self.w13_weight_aiter_tensor, + cktile_moe_out1, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + layer.top_k, + self.intermediate_pad // 64 * 64 * 2, + self.hidden_pad // 128 * 128, # k_pad_zeros + None, # sorted_weights + None, + self.w13_scale_aiter_tensor, + self.w13_bias_aiter_tensor, + BLOCKM, # block_size + ) + moe_cktile2stages_gemm2( + cktile_moe_out1, + self.w2_weight_aiter_tensor, + moe_out, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + layer.top_k, + self.hidden_pad // 64 * 64, # n_pad_zeros + self.intermediate_pad // 128 * 128, + sorted_weights, # sorted_weights + None, + self.w2_scale_aiter_tensor, + layer.w2_bias, + BLOCKM, # block_size + ) + return moe_out + else: + from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 + triton_kernel_moe_oss_forward, + ) + + return triton_kernel_moe_oss_forward( + hidden_states=x, + w1=self.w13_weight_triton_tensor, + w2=self.w2_weight_triton_tensor, + gating_output=router_logits, + topk=layer.top_k, + renormalize=layer.renormalize, + global_num_experts=layer.global_num_experts, + expert_map=expert_map, + quant_config=self.moe_quant_config, + 
def _load_weights_quark(
    self,
    ep_rank_start: int,
    ep_rank_end: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, torch.Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
    """Load a Quark-format checkpoint into this model's parameters.

    Per-expert MoE tensors are recognised by a marker substring in their
    name. Their second-to-last dotted name component is parsed as the
    expert index and removed to obtain the fused parameter name; the
    index is handed to the parameter's ``weight_loader`` as ``expert_id``.
    All remaining tensors go through ``stacked_params_mapping``, then the
    expert mapping from ``self.get_expert_mapping()``, then a plain load.

    Args:
        ep_rank_start: Unused; expert parallelism is rejected up front.
        ep_rank_end: Unused; expert parallelism is rejected up front.
            NOTE(review): the call site in this class passes
            ``ep_rank_end, ep_rank_start`` positionally, i.e. swapped
            relative to this signature. Harmless while the values are
            unused, but it must be reconciled before enabling EP here.
        heads_per_rank: Length of the dim-0 window taken from "sinks"
            tensors.
        head_start: Start of the dim-0 window taken from "sinks" tensors.
        weights: ``(name, tensor)`` pairs from the checkpoint.
        stacked_params_mapping: ``(param_name, weight_name, shard_id)``
            triples for stacked (non-expert) parameters.

    Returns:
        The set of parameter names that were loaded.

    Raises:
        NotImplementedError: If expert parallelism is enabled.
    """
    # An explicit raise (rather than ``assert``) keeps this guard active
    # when Python runs with -O.
    if self.parallel_config.enable_expert_parallel:
        raise NotImplementedError(
            "Expert parallelism is not supported for quark MoE"
        )

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    # presumably the number of values covered by one MXFP4 scale entry
    # -- TODO confirm
    mxfp4_block = 32

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()

    intermediate_size = self.config.intermediate_size
    per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
    # Calculate common slicing bounds for current rank
    tp_rank_start = tp_rank * per_rank_intermediate_size
    tp_rank_end = min(
        (tp_rank + 1) * per_rank_intermediate_size, intermediate_size
    )

    # w13 tensors use doubled bounds on dim 0; since the weight is
    # shuffled, we can slice directly.
    w13_rows = slice(2 * tp_rank_start, 2 * tp_rank_end)
    # w2 scales are indexed per mxfp4 block along the last dim.
    w2_scale_cols = slice(
        tp_rank_start // mxfp4_block, tp_rank_end // mxfp4_block
    )
    # w2 weights use halved bounds on the last dim (presumably two 4-bit
    # values are packed per element -- TODO confirm).
    w2_weight_cols = slice(tp_rank_start // 2, tp_rank_end // 2)

    def unsharded(tensor: torch.Tensor) -> torch.Tensor:
        return tensor

    def w13_shard(tensor: torch.Tensor) -> torch.Tensor:
        return tensor[w13_rows, ...]

    def w2_scale_shard(tensor: torch.Tensor) -> torch.Tensor:
        return tensor[..., w2_scale_cols]

    def w2_weight_shard(tensor: torch.Tensor) -> torch.Tensor:
        return tensor[..., w2_weight_cols]

    def w2_bias_shard(tensor: torch.Tensor) -> torch.Tensor:
        # (only load on rank 0 to avoid duplication); other ranks zero
        # the checkpoint tensor in place before loading it.
        if tp_rank != 0:
            tensor.zero_()
        return tensor

    # Order matters: the first marker found in the name wins, so the
    # "*_weight_scale" markers must precede the "*_weight" markers they
    # contain.
    expert_tensor_shards = (
        ("input_scale", unsharded),
        (".w13_weight_scale", w13_shard),
        (".w2_weight_scale", w2_scale_shard),
        (".w13_weight", w13_shard),
        (".w2_weight", w2_weight_shard),
        (".w13_bias", w13_shard),
        (".w2_bias", w2_bias_shard),
    )

    def load_expert_tensor(ckpt_name: str, tensor: torch.Tensor) -> str:
        """Load one per-expert tensor and return the fused param name."""
        parts = ckpt_name.split(".")
        expert_id = int(parts[-2])
        fused_name = ".".join(parts[:-2] + parts[-1:])
        param = params_dict[fused_name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(
            param,
            tensor,
            weight_name=fused_name,
            shard_id=None,
            expert_id=expert_id,
        )
        return fused_name

    expert_params_mapping = self.get_expert_mapping()
    for name, loaded_weight in weights:
        if "sinks" in name:
            # Handle attention sinks (distributed across ranks)
            param = params_dict[name]
            narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)
            continue

        shard_fn = next(
            (fn for marker, fn in expert_tensor_shards if marker in name),
            None,
        )
        if shard_fn is not None:
            loaded_params.add(load_expert_tensor(name, shard_fn(loaded_weight)))
            continue

        for param_name, weight_name, shard_id in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if ("mlp.experts." in name) and name not in params_dict:
                continue
            name = name.replace(weight_name, param_name)

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue

                # Anyway, this is an expert weight and should not be
                # attempted to load as other weights later
                is_expert_weight = True

                # Do not modify `name` since the loop may continue here
                # Instead, create a new variable
                name_mapped = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name_mapped, self):
                    continue

                param = params_dict[name_mapped]
                # We should ask the weight loader to return success or not
                # here since otherwise we may skip experts with other
                # available replicas.
                weight_loader = typing.cast(
                    typing.Callable[..., bool], param.weight_loader
                )
                success = weight_loader(
                    param,
                    loaded_weight,
                    name_mapped,
                    shard_id=shard_id,
                    expert_id=expert_id,
                    return_success=True,
                )
                if success:
                    name = name_mapped
                    break
            else:
                if is_expert_weight:
                    # We've checked that this is an expert weight
                    # However it's not mapped locally to this rank
                    # So we simply skip it
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
        loaded_params.add(name)

    return loaded_params
num_experts=self.config.num_local_experts, + ) + def _load_weights_other( self, ep_rank_end: int, @@ -669,6 +951,15 @@ class GptOssModel(nn.Module): weights, stacked_params_mapping, ) + elif quant_method == "quark": + return self._load_weights_quark( + ep_rank_end, + ep_rank_start, + heads_per_rank, + head_start, + weights, + stacked_params_mapping, + ) else: return self._load_weights_other( ep_rank_end, @@ -701,6 +992,15 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): # MoE Bias ".gate_up_proj_bias": ".w13_bias", ".down_proj_bias": ".w2_bias", + # For quark format + ".gate_up_proj.weight": ".w13_weight", + ".gate_up_proj.weight_scale": ".w13_weight_scale", + ".gate_up_proj.bias": ".w13_bias", + ".gate_up_proj.input_scale": ".w13_input_scale", + ".down_proj.weight": ".w2_weight", + ".down_proj.weight_scale": ".w2_weight_scale", + ".down_proj.bias": ".w2_bias", + ".down_proj.input_scale": ".w2_input_scale", }, )