diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index 3640e5c452786..39bcd56bcd3dc 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -218,6 +218,49 @@ class QuarkConfig(QuantizationConfig):
         else:
             return False
 
+    def _is_fp8_w4a8(
+        self,
+        weight_quant: list[dict[str, Any]] | None,
+        input_quant: dict[str, Any] | None,
+    ) -> bool:
+        # Confirm that both weights and input are quantized.
+        if weight_quant is None or input_quant is None:
+            return False
+
+        if not isinstance(weight_quant, list) or len(weight_quant) != 2:
+            return False
+
+        # Confirm the weight scheme is supported.
+        is_w4a8_dtype = (
+            weight_quant[0].get("dtype") == "fp8_e4m3"
+            and weight_quant[1].get("dtype") == "int4"
+            and input_quant.get("dtype") == "fp8_e4m3"
+        )
+        is_static_weight = not weight_quant[0].get("is_dynamic") and not weight_quant[
+            1
+        ].get("is_dynamic")
+        is_per_tensor_fp8_and_per_channel_int4_weight = (
+            weight_quant[0].get("qscheme") == "per_tensor"
+            and weight_quant[1].get("qscheme") == "per_channel"
+            and weight_quant[1].get("symmetric") is True
+            and weight_quant[1].get("ch_axis") == 0
+        )
+
+        if not (
+            is_w4a8_dtype
+            and is_static_weight
+            and is_per_tensor_fp8_and_per_channel_int4_weight
+        ):
+            return False
+
+        # Dynamic activation quantization is always supported if the weights are.
+        if input_quant.get("is_dynamic"):
+            return True
+
+        # Confirm the activation scheme is supported.
+        is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor"
+        return is_per_tensor_activation
+
     def _is_fp8_w8a8(
         self,
         weight_quant: dict[str, Any] | None,
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index d84e22d1fa0f2..0b9b098afb1f6 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -63,8 +63,9 @@ class QuarkMoEMethod(FusedMoEMethodBase):
             )
         weight_config = layer_quant_config.get("weight")
         input_config = layer_quant_config.get("input_tensors")
-
-        if quant_config._is_fp8_w8a8(weight_config, input_config):
+        if quant_config._is_fp8_w4a8(weight_config, input_config):
+            return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
+        elif quant_config._is_fp8_w8a8(weight_config, input_config):
             return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
         elif quant_config._is_ocp_mx(weight_config, input_config):
             return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
@@ -396,6 +397,161 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         )
 
 
+class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
+    def __init__(
+        self,
+        weight_config: dict[str, Any],
+        input_config: dict[str, Any],
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
+        self.weight_quant = weight_config
+        self.input_quant = input_config
+
+        assert rocm_aiter_ops.is_fused_moe_enabled(), (
+            "W4A8 FP8 MoE requires ROCm AITER fused MoE support."
+        )
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        params_dtype = torch.uint32
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // 8,  # INT32 packing for W4
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // 8,  # INT32 packing for W4
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # Per-tensor fp8 weight scales
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # Per-channel int4 weight scales
+        w13_weight_scale_2 = torch.nn.Parameter(
+            torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                dtype=torch.float32,
+            ),
+            requires_grad=False,
+        )
+        w2_weight_scale_2 = torch.nn.Parameter(
+            torch.ones(num_experts, hidden_size, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
+        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+        )
+        set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
+            layer.w13_weight.data, layer.w2_weight.data
+        )
+        layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
+        layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+
+        # INT4-FP8: offset the per-shard fp8 scales onto a single w13_weight_scale.
+        # The fp8 MoE kernel needs one fp8 w13_weight_scale per expert for w13.
+        # We do not requantize each expert's fp8 weight (not directly available);
+        # instead we adjust half of the INT4 w13_weight_scale_2 values.
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        assert torch.all(max_w13_scales != 0), "fp8 weight scale cannot be zero."
+        for expert_id in range(layer.local_num_experts):
+            start = 0
+            max_w13_scale_fp8 = max_w13_scales[expert_id]
+            for shard_id in range(2):
+                if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
+                    int4_rescale = (
+                        layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
+                    )
+                    layer.w13_weight_scale_2[expert_id][start : start + shard_size] *= (
+                        int4_rescale
+                    )
+                start += shard_size
+
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+
+        # Special hack for asm_moe, which applies (weight_scale1 * weight_scale) as a
+        # post-GEMM scale; the optimal design would apply the per-column weight_scale1
+        # before the GEMM and weight_scale after it.
+        for expert_id in range(layer.local_num_experts):
+            layer.w13_weight_scale_2[expert_id] *= max_w13_scales[expert_id]
+            layer.w2_weight_scale_2[expert_id] *= layer.w2_weight_scale[expert_id]
+
+    def get_fused_moe_quant_config(self, layer):
+        return fp8_w8a8_moe_quant_config(
+            w1_scale=layer.w13_weight_scale_2,
+            w2_scale=layer.w2_weight_scale_2,
+            per_out_ch_quant=True,
+        )
+
+    def apply(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        topk_weights, topk_ids, _ = layer.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+        )
+
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_fused_experts,
+        )
+
+        return rocm_aiter_fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            activation=layer.activation,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
+            quant_config=self.moe_quant_config,
+            expert_map=layer.expert_map,
+        )
+
+
 class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
     def __init__(
        self,
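# --- Illustrative note, not part of the patch: a minimal sketch of the per-layer
# config shape that the new QuarkConfig._is_fp8_w4a8() check accepts, derived only
# from the conditions added above. The dicts show just the keys the check inspects;
# a real Quark checkpoint's layer_quant_config["weight"] / ["input_tensors"] entries
# may carry additional fields.
weight_config = [
    # Per-tensor fp8 weight spec (static).
    {"dtype": "fp8_e4m3", "qscheme": "per_tensor", "is_dynamic": False},
    # Per-channel int4 weight spec (static, symmetric, ch_axis 0).
    {
        "dtype": "int4",
        "qscheme": "per_channel",
        "symmetric": True,
        "ch_axis": 0,
        "is_dynamic": False,
    },
]
input_config = {"dtype": "fp8_e4m3", "qscheme": "per_tensor", "is_dynamic": False}

# With this shape, QuarkConfig._is_fp8_w4a8(weight_config, input_config) returns True,
# so the dispatch added in quark_moe.py selects the new QuarkW4A8Fp8MoEMethod.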