mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 00:27:12 +08:00
[Quantization] Support Quark int4-fp8 w4a8 for MoE (#30071)
Signed-off-by: Bowen Bao <bowenbao@amd.com>
This commit is contained in:
parent
5a3adf581e
commit
0c738b58bc
@ -218,6 +218,49 @@ class QuarkConfig(QuantizationConfig):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _is_fp8_w4a8(
|
||||||
|
self,
|
||||||
|
weight_quant: list[dict[str, Any]] | None,
|
||||||
|
input_quant: dict[str, Any] | None,
|
||||||
|
) -> bool:
|
||||||
|
# Confirm weights and input quantized.
|
||||||
|
if weight_quant is None or input_quant is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not isinstance(weight_quant, list) or len(weight_quant) != 2:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Confirm weight scheme is supported
|
||||||
|
is_w4a8_dtype = (
|
||||||
|
weight_quant[0].get("dtype") == "fp8_e4m3"
|
||||||
|
and weight_quant[1].get("dtype") == "int4"
|
||||||
|
and input_quant.get("dtype") == "fp8_e4m3"
|
||||||
|
)
|
||||||
|
is_static_weight = not weight_quant[0].get("is_dynamic") and not weight_quant[
|
||||||
|
1
|
||||||
|
].get("is_dynamic")
|
||||||
|
is_per_tensor_fp8_and_per_channel_int4_weight = (
|
||||||
|
weight_quant[0].get("qscheme") == "per_tensor"
|
||||||
|
and weight_quant[1].get("qscheme") == "per_channel"
|
||||||
|
and weight_quant[1].get("symmetric") is True
|
||||||
|
and weight_quant[1].get("ch_axis") == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
if not (
|
||||||
|
is_w4a8_dtype
|
||||||
|
and is_static_weight
|
||||||
|
and is_per_tensor_fp8_and_per_channel_int4_weight
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Dynamic quantization is always supported if weights supported.
|
||||||
|
if input_quant.get("is_dynamic"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Confirm activation scheme is supported.
|
||||||
|
is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor"
|
||||||
|
return is_per_tensor_activation
|
||||||
|
|
||||||
def _is_fp8_w8a8(
|
def _is_fp8_w8a8(
|
||||||
self,
|
self,
|
||||||
weight_quant: dict[str, Any] | None,
|
weight_quant: dict[str, Any] | None,
|
||||||
|
|||||||
@ -63,8 +63,9 @@ class QuarkMoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
weight_config = layer_quant_config.get("weight")
|
weight_config = layer_quant_config.get("weight")
|
||||||
input_config = layer_quant_config.get("input_tensors")
|
input_config = layer_quant_config.get("input_tensors")
|
||||||
|
if quant_config._is_fp8_w4a8(weight_config, input_config):
|
||||||
if quant_config._is_fp8_w8a8(weight_config, input_config):
|
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||||
|
elif quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||||
elif quant_config._is_ocp_mx(weight_config, input_config):
|
elif quant_config._is_ocp_mx(weight_config, input_config):
|
||||||
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
|
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
|
||||||
@ -396,6 +397,161 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight_config: dict[str, Any],
|
||||||
|
input_config: dict[str, Any],
|
||||||
|
moe: FusedMoEConfig,
|
||||||
|
):
|
||||||
|
super().__init__(moe)
|
||||||
|
self.weight_quant = weight_config
|
||||||
|
self.input_quant = input_config
|
||||||
|
|
||||||
|
assert rocm_aiter_ops.is_fused_moe_enabled(), (
|
||||||
|
"W4A8 FP8 MoE requires ROCm AITER fused MoE support."
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
params_dtype = torch.uint32
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
hidden_size // 8, # INT32 packing for W4
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size_per_partition // 8, # INT32 packing for W4
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# Per-tensor fp8 weight scales
|
||||||
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||||
|
)
|
||||||
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
# Per-channel int4 weight scales
|
||||||
|
w13_weight_scale_2 = torch.nn.Parameter(
|
||||||
|
torch.ones(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
w2_weight_scale_2 = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, hidden_size, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
|
||||||
|
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||||
|
)
|
||||||
|
set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||||
|
layer.w13_weight.data, layer.w2_weight.data
|
||||||
|
)
|
||||||
|
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||||
|
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||||
|
|
||||||
|
# INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
|
||||||
|
# Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
|
||||||
|
# We won't do requant each expert's fp8 weight (not direct available),
|
||||||
|
# instead we adjust half of INT4 w13_weight_scale1 numbers
|
||||||
|
shard_size = layer.intermediate_size_per_partition
|
||||||
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
|
assert torch.all(max_w13_scales != 0), "fp8 weight scale cannot be zero."
|
||||||
|
for expert_id in range(layer.local_num_experts):
|
||||||
|
start = 0
|
||||||
|
max_w13_scale_fp8 = max_w13_scales[expert_id]
|
||||||
|
for shard_id in range(2):
|
||||||
|
if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
|
||||||
|
int4_rescale = (
|
||||||
|
layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
|
||||||
|
)
|
||||||
|
layer.w13_weight_scale_2[expert_id][start : start + shard_size] *= (
|
||||||
|
int4_rescale
|
||||||
|
)
|
||||||
|
start += shard_size
|
||||||
|
|
||||||
|
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
|
||||||
|
|
||||||
|
# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post
|
||||||
|
# GEMM scaling optimal design - shall apply per-column weight_scale1 before
|
||||||
|
# GEMM, and weight_scale post
|
||||||
|
for expert_id in range(layer.local_num_experts):
|
||||||
|
layer.w13_weight_scale_2[expert_id] *= max_w13_scales[expert_id]
|
||||||
|
layer.w2_weight_scale_2[expert_id] *= layer.w2_weight_scale[expert_id]
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(self, layer):
|
||||||
|
return fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=layer.w13_weight_scale_2,
|
||||||
|
w2_scale=layer.w2_weight_scale_2,
|
||||||
|
per_out_ch_quant=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: FusedMoE,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
topk_weights, topk_ids, _ = layer.select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
|
rocm_aiter_fused_experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
return rocm_aiter_fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
activation=layer.activation,
|
||||||
|
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
expert_map=layer.expert_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user