Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Quantization] Support Quark int4-fp8 w4a8 for MoE (#30071)
Signed-off-by: Bowen Bao <bowenbao@amd.com>
Parent: 5a3adf581e
Commit: 0c738b58bc
@@ -218,6 +218,49 @@ class QuarkConfig(QuantizationConfig):
         else:
             return False

+    def _is_fp8_w4a8(
+        self,
+        weight_quant: list[dict[str, Any]] | None,
+        input_quant: dict[str, Any] | None,
+    ) -> bool:
+        # Confirm that both weights and inputs are quantized.
+        if weight_quant is None or input_quant is None:
+            return False
+
+        if not isinstance(weight_quant, list) or len(weight_quant) != 2:
+            return False
+
+        # Confirm that the weight scheme is supported.
+        is_w4a8_dtype = (
+            weight_quant[0].get("dtype") == "fp8_e4m3"
+            and weight_quant[1].get("dtype") == "int4"
+            and input_quant.get("dtype") == "fp8_e4m3"
+        )
+        is_static_weight = not weight_quant[0].get("is_dynamic") and not weight_quant[
+            1
+        ].get("is_dynamic")
+        is_per_tensor_fp8_and_per_channel_int4_weight = (
+            weight_quant[0].get("qscheme") == "per_tensor"
+            and weight_quant[1].get("qscheme") == "per_channel"
+            and weight_quant[1].get("symmetric") is True
+            and weight_quant[1].get("ch_axis") == 0
+        )
+
+        if not (
+            is_w4a8_dtype
+            and is_static_weight
+            and is_per_tensor_fp8_and_per_channel_int4_weight
+        ):
+            return False
+
+        # Dynamic quantization is always supported if the weights are supported.
+        if input_quant.get("is_dynamic"):
+            return True
+
+        # Confirm that the activation scheme is supported.
+        is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor"
+        return is_per_tensor_activation
+
     def _is_fp8_w8a8(
         self,
         weight_quant: dict[str, Any] | None,
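Aside (editor's note, not part of the diff): a minimal sketch of a layer quantization config that _is_fp8_w4a8 would accept, assuming the Quark export uses exactly the keys read above; a real export carries additional fields.

weight_config = [
    {"dtype": "fp8_e4m3", "qscheme": "per_tensor", "is_dynamic": False},
    {
        "dtype": "int4",
        "qscheme": "per_channel",
        "symmetric": True,
        "ch_axis": 0,
        "is_dynamic": False,
    },
]
input_config = {"dtype": "fp8_e4m3", "qscheme": "per_tensor", "is_dynamic": False}
# quark_config._is_fp8_w4a8(weight_config, input_config) would return True;
# with "is_dynamic": True on input_config it would also return True.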
@@ -63,8 +63,9 @@ class QuarkMoEMethod(FusedMoEMethodBase):
             )
         weight_config = layer_quant_config.get("weight")
         input_config = layer_quant_config.get("input_tensors")

-        if quant_config._is_fp8_w8a8(weight_config, input_config):
+        if quant_config._is_fp8_w4a8(weight_config, input_config):
+            return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
+        elif quant_config._is_fp8_w8a8(weight_config, input_config):
             return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
         elif quant_config._is_ocp_mx(weight_config, input_config):
             return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
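Aside (editor's note, not part of the diff): the new w4a8 check runs before the existing w8a8 check. Both schemes use fp8_e4m3 activations, but a w4a8 weight config is a two-entry list (per-tensor fp8 plus per-channel int4) rather than a single dict, so the more specific scheme is matched first.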
@@ -396,6 +397,161 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         )


+class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
+    def __init__(
+        self,
+        weight_config: dict[str, Any],
+        input_config: dict[str, Any],
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
+        self.weight_quant = weight_config
+        self.input_quant = input_config
+
+        assert rocm_aiter_ops.is_fused_moe_enabled(), (
+            "W4A8 FP8 MoE requires ROCm AITER fused MoE support."
+        )
+
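Aside (editor's note, not part of the diff): the assert makes this path fail fast at construction time. The w4a8 kernel invoked in apply() below comes from ROCm AITER, and there is no fallback implementation when AITER fused MoE is unavailable.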
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        params_dtype = torch.uint32
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // 8,  # INT32 packing for W4
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // 8,  # INT32 packing for W4
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # Per-tensor fp8 weight scales
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # Per-channel int4 weight scales
+        w13_weight_scale_2 = torch.nn.Parameter(
+            torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                dtype=torch.float32,
+            ),
+            requires_grad=False,
+        )
+        w2_weight_scale_2 = torch.nn.Parameter(
+            torch.ones(num_experts, hidden_size, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
+        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+        )
+        set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)
+
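Aside (editor's note, not part of the diff): the uint32 weight tensors above hold eight 4-bit values per element, which is where the hidden_size // 8 and intermediate_size_per_partition // 8 trailing dimensions come from. A rough sketch of that packing arithmetic, assuming little-endian nibble order (the actual layout expected by the AITER kernel is defined by the checkpoint exporter and may differ):

import torch

def pack_int4_to_uint32(x: torch.Tensor) -> torch.Tensor:
    # x: integer tensor with values in [-8, 7]; last dim divisible by 8.
    assert x.shape[-1] % 8 == 0
    nibbles = (x.to(torch.int64) & 0xF).reshape(*x.shape[:-1], -1, 8)
    shifts = torch.arange(8, dtype=torch.int64, device=x.device) * 4
    words = (nibbles << shifts).sum(dim=-1)  # each packed word fits in 32 bits
    return words.to(torch.uint32)  # needs a PyTorch build with torch.uint32 support

For a (num_experts, 2 * intermediate, hidden) int4 weight this yields exactly the (num_experts, 2 * intermediate, hidden // 8) uint32 shape allocated for w13_weight above.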
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
+            layer.w13_weight.data, layer.w2_weight.data
+        )
+        layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
+        layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+
+        # INT4-FP8: fold the per-shard fp8 scales into a single per-expert
+        # w13_weight_scale. The fp8 MoE kernel needs one fp8 w13_weight_scale
+        # per expert for w13. Rather than requantizing each expert's fp8
+        # weight (the original weight is not directly available), adjust the
+        # half of the INT4 per-channel scales (w13_weight_scale_2) belonging
+        # to the shard whose fp8 scale is not the maximum.
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        assert torch.all(max_w13_scales != 0), "fp8 weight scale cannot be zero."
+        for expert_id in range(layer.local_num_experts):
+            start = 0
+            max_w13_scale_fp8 = max_w13_scales[expert_id]
+            for shard_id in range(2):
+                if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
+                    int4_rescale = (
+                        layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
+                    )
+                    layer.w13_weight_scale_2[expert_id][start : start + shard_size] *= (
+                        int4_rescale
+                    )
+                start += shard_size
+
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+
+        # Special handling for asm_moe, which applies (weight_scale1 *
+        # weight_scale) as a single post-GEMM scaling. (The optimal design
+        # would apply the per-column weight_scale1 before the GEMM and
+        # weight_scale after it.) Fold the per-tensor scale into the
+        # per-channel scales so the kernel sees one combined scale.
+        for expert_id in range(layer.local_num_experts):
+            layer.w13_weight_scale_2[expert_id] *= max_w13_scales[expert_id]
+            layer.w2_weight_scale_2[expert_id] *= layer.w2_weight_scale[expert_id]
+
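Aside (editor's note, not part of the diff): a tiny numeric check of the scale folding above. The combined per-channel scale the kernel ends up with equals the original fp8 shard scale times the original int4 per-channel scale, so the packed weights never need requantizing. Names here are illustrative.

import torch

s_fp8 = torch.tensor([0.5, 2.0])   # per-shard fp8 scales for (w1, w3)
s_int4 = torch.ones(4)             # per-channel int4 scales, 2 channels per shard
s_max = s_fp8.max()

# First loop: rescale the shard whose fp8 scale is not the maximum.
s_int4[0:2] *= s_fp8[0] / s_max    # shard 0 absorbs the ratio
# Second loop: fold the single remaining fp8 scale into every channel.
s_int4 *= s_max

# Combined scale == original s_fp8[shard] * s_int4[channel] for each channel.
assert torch.allclose(s_int4, torch.tensor([0.5, 0.5, 2.0, 2.0]))

After this folding, w13_weight_scale_2 and w2_weight_scale_2 carry the entire dequantization scale, which is why get_fused_moe_quant_config below passes them as w1_scale / w2_scale with per_out_ch_quant=True.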
+    def get_fused_moe_quant_config(self, layer):
+        return fp8_w8a8_moe_quant_config(
+            w1_scale=layer.w13_weight_scale_2,
+            w2_scale=layer.w2_weight_scale_2,
+            per_out_ch_quant=True,
+        )
+
+    def apply(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        topk_weights, topk_ids, _ = layer.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+        )
+
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_fused_experts,
+        )
+
+        return rocm_aiter_fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            activation=layer.activation,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
+            quant_config=self.moe_quant_config,
+            expert_map=layer.expert_map,
+        )
+
+
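Aside (editor's note, not part of the diff): a hypothetical end-to-end invocation on a ROCm build with AITER fused MoE enabled. The checkpoint path is a placeholder; any Quark-exported int4-fp8 MoE checkpoint whose per-layer config matches the shape sketched earlier should be routed to QuarkW4A8Fp8MoEMethod by the dispatch above.

from vllm import LLM

llm = LLM(model="path/to/quark-int4-fp8-moe-checkpoint", quantization="quark")
print(llm.generate("Hello, world!"))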
 class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
     def __init__(
         self,