[Quantization] Support Quark int4-fp8 w4a8 for MoE (#30071)

Signed-off-by: Bowen Bao <bowenbao@amd.com>
Authored by Bowen Bao on 2025-12-17 20:20:42 -08:00; committed by GitHub
parent 5a3adf581e
commit 0c738b58bc
2 changed files with 201 additions and 2 deletions


@@ -218,6 +218,49 @@ class QuarkConfig(QuantizationConfig):
        else:
            return False

    def _is_fp8_w4a8(
        self,
        weight_quant: list[dict[str, Any]] | None,
        input_quant: dict[str, Any] | None,
    ) -> bool:
        # Confirm weights and input quantized.
        if weight_quant is None or input_quant is None:
            return False
        if not isinstance(weight_quant, list) or len(weight_quant) != 2:
            return False

        # Confirm weight scheme is supported
        is_w4a8_dtype = (
            weight_quant[0].get("dtype") == "fp8_e4m3"
            and weight_quant[1].get("dtype") == "int4"
            and input_quant.get("dtype") == "fp8_e4m3"
        )
        is_static_weight = not weight_quant[0].get("is_dynamic") and not weight_quant[
            1
        ].get("is_dynamic")
        is_per_tensor_fp8_and_per_channel_int4_weight = (
            weight_quant[0].get("qscheme") == "per_tensor"
            and weight_quant[1].get("qscheme") == "per_channel"
            and weight_quant[1].get("symmetric") is True
            and weight_quant[1].get("ch_axis") == 0
        )
        if not (
            is_w4a8_dtype
            and is_static_weight
            and is_per_tensor_fp8_and_per_channel_int4_weight
        ):
            return False

        # Dynamic quantization is always supported if weights supported.
        if input_quant.get("is_dynamic"):
            return True

        # Confirm activation scheme is supported.
        is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor"
        return is_per_tensor_activation

    def _is_fp8_w8a8(
        self,
        weight_quant: dict[str, Any] | None,
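For reference, a minimal sketch of a per-layer Quark config that the new check accepts. The dicts below are illustrative only and limited to the keys _is_fp8_w4a8() actually reads (not a full Quark serialization); quark_config stands for an instantiated QuarkConfig:

    # Hypothetical example, not part of this change: weight entry 0 is the fp8
    # per-tensor scheme, entry 1 the int4 per-channel scheme, activations are fp8.
    example_weight_quant = [
        {"dtype": "fp8_e4m3", "is_dynamic": False, "qscheme": "per_tensor"},
        {"dtype": "int4", "is_dynamic": False, "qscheme": "per_channel",
         "symmetric": True, "ch_axis": 0},
    ]
    example_input_quant = {"dtype": "fp8_e4m3", "is_dynamic": False, "qscheme": "per_tensor"}

    assert quark_config._is_fp8_w4a8(example_weight_quant, example_input_quant)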


@@ -63,8 +63,9 @@ class QuarkMoEMethod(FusedMoEMethodBase):
             )
         weight_config = layer_quant_config.get("weight")
         input_config = layer_quant_config.get("input_tensors")
-        if quant_config._is_fp8_w8a8(weight_config, input_config):
+        if quant_config._is_fp8_w4a8(weight_config, input_config):
+            return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
+        elif quant_config._is_fp8_w8a8(weight_config, input_config):
             return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
         elif quant_config._is_ocp_mx(weight_config, input_config):
             return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
@@ -396,6 +397,161 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
        )


class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config

        assert rocm_aiter_ops.is_fused_moe_enabled(), (
            "W4A8 FP8 MoE requires ROCm AITER fused MoE support."
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        params_dtype = torch.uint32

        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // 8,  # INT32 packing for W4
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 8,  # INT32 packing for W4
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Per-tensor fp8 weight scales
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Per-channel int4 weight scales
        w13_weight_scale_2 = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        w2_weight_scale_2 = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)
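The // 8 in the weight shapes above comes from packing eight 4-bit values into one 32-bit word. The sketch below is illustration only: pack_int4_rows is a hypothetical helper, not part of this change, and real checkpoints already ship the packed tensors.

    import torch

    def pack_int4_rows(w_int4: torch.Tensor) -> torch.Tensor:
        # w_int4: [..., K] with unsigned nibble values in [0, 15]; K must be a multiple of 8.
        *lead, k = w_int4.shape
        assert k % 8 == 0
        nibbles = w_int4.to(torch.int64).reshape(*lead, k // 8, 8)
        shifts = torch.arange(8, dtype=torch.int64) * 4  # 4 bits per value
        packed = (nibbles << shifts).sum(dim=-1)  # one 32-bit word per 8 values
        return packed  # the layer stores these words as torch.uint32

    # Shape check mirroring w13_weight: the last dim shrinks from hidden_size to hidden_size // 8.
    w = torch.randint(0, 16, (2, 4, 64))  # [num_experts, rows, hidden_size]
    assert pack_int4_rows(w).shape == (2, 4, 64 // 8)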
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
            layer.w13_weight.data, layer.w2_weight.data
        )
        layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

        # INT4-FP8: collapse the two fp8 w13 scales into a single w13_weight_scale.
        # The fp8 MoE kernel needs one fp8 w13_weight_scale per expert for w13. Instead
        # of requantizing each expert's fp8 weight (not directly available), adjust the
        # half of the INT4 w13_weight_scale_2 entries belonging to the shard whose fp8
        # scale is not the per-expert maximum.
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        assert torch.all(max_w13_scales != 0), "fp8 weight scale cannot be zero."
        for expert_id in range(layer.local_num_experts):
            start = 0
            max_w13_scale_fp8 = max_w13_scales[expert_id]
            for shard_id in range(2):
                if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
                    int4_rescale = (
                        layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
                    )
                    layer.w13_weight_scale_2[expert_id][start : start + shard_size] *= (
                        int4_rescale
                    )
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)

        # asm_moe applies (weight_scale_2 * weight_scale) as a single post-GEMM scaling;
        # the optimal design would apply the per-column weight_scale_2 before the GEMM
        # and weight_scale after it. Fold the per-tensor scale into the per-channel one.
        for expert_id in range(layer.local_num_experts):
            layer.w13_weight_scale_2[expert_id] *= max_w13_scales[expert_id]
            layer.w2_weight_scale_2[expert_id] *= layer.w2_weight_scale[expert_id]
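A toy walk-through of the rescaling above, with made-up numbers and intermediate_size_per_partition = 2: the fused kernel receives a single fp8 scale per expert, so the shard whose fp8 scale is smaller absorbs the ratio into its per-channel int4 scales, and both halves then absorb the fp8 scale itself per the asm_moe convention noted in the comment.

    import torch

    shard_size = 2  # toy intermediate_size_per_partition
    w13_fp8_scale = torch.tensor([1.0, 2.0])  # per-tensor fp8 scales for the w1 / w3 shards
    w13_int4_scale = torch.tensor([3.0, 3.0, 5.0, 5.0])  # per-channel int4 scales (w1 rows, then w3 rows)

    max_scale = w13_fp8_scale.max()  # 2.0 becomes the expert's single fp8 scale
    w13_int4_scale[0:shard_size] *= w13_fp8_scale[0] / max_scale  # w1 shard rescaled by 0.5
    w13_int4_scale *= max_scale  # kernel applies (scale_2 * scale) post-GEMM

    # The effective per-channel factor fp8_scale * int4_scale is unchanged.
    assert torch.allclose(w13_int4_scale, torch.tensor([3.0, 3.0, 10.0, 10.0]))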
    def get_fused_moe_quant_config(self, layer):
        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale_2,
            w2_scale=layer.w2_weight_scale_2,
            per_out_ch_quant=True,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_fused_experts,
        )

        return rocm_aiter_fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            quant_config=self.moe_quant_config,
            expert_map=layer.expert_map,
        )


class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
    def __init__(
        self,