[Rocm] [quantization] Fix quark ptpc moe and add test case (#24649)

Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
Co-authored-by: Haoyang Li <haoyang.li@amd.com>
This commit is contained in:
haoyangli-amd 2025-09-17 13:15:13 +08:00 committed by GitHub
parent 0f7acdd73c
commit ca2d1925ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 196 additions and 52 deletions

View File

@@ -77,6 +77,31 @@ def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
assert output assert output
@pytest.mark.parametrize('tp', [1])
def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
    """End-to-end check for a Quark PTPC (per-channel weight, per-token
    activation) FP8 checkpoint.

    Loads the quantized Qwen2.5 model, inspects the first decoder layer to
    confirm the qkv projection uses QuarkLinearMethod with the QuarkW8A8Fp8
    scheme and carries one weight scale per output channel (trailing dim 1),
    then runs a short greedy generation as a smoke test.
    """
    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"
    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]
            qkv_proj = first_layer.self_attn.qkv_proj

            # The layer must be wired up with the Quark FP8 scheme.
            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
                # Weights are stored in the platform's FP8 dtype, with a
                # per-output-channel scale column vector: (out_channels, 1).
                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                scale_shape = qkv_proj.weight_scale.shape
                assert scale_shape[0] == qkv_proj.weight.shape[1]
                assert scale_shape[1] == 1

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
@pytest.mark.parametrize('tp', [1]) @pytest.mark.parametrize('tp', [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp): def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test" model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"

View File

@@ -5,17 +5,25 @@ from typing import Any, Callable, Optional, Union
import torch import torch
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE) OCP_MX_BLOCK_SIZE)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -67,21 +75,45 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
self.weight_quant = weight_config self.weight_quant = weight_config
self.input_quant = input_config self.input_quant = input_config
weight_qscheme = self.weight_quant.get("qscheme") self.weight_qscheme = self.weight_quant.get("qscheme")
input_qscheme = self.input_quant.get("qscheme") self.input_qscheme = self.input_quant.get("qscheme")
if not (weight_qscheme == "per_tensor" per_tensor = (self.weight_qscheme == "per_tensor"
and input_qscheme == "per_tensor"): and self.input_qscheme == "per_tensor")
per_channel = (self.weight_qscheme == "per_channel"
and self.input_qscheme == "per_channel")
self.act_quant_group_shape = GroupShape.PER_TOKEN \
if per_channel else GroupShape.PER_TENSOR
if not (per_tensor or per_channel):
raise ValueError( raise ValueError(
"For FP8 Fused MoE layers, only per-tensor scales " "For FP8 Fused MoE layers, only per-tensor and per-channel "
"for weights and activations are supported. Found " "scales for weights and activations are supported. Found "
f"{weight_qscheme}, {input_qscheme}") # noqa E501 f"{self.weight_qscheme}, {self.input_qscheme}") # noqa E501
self.static_input_scales = not self.input_quant.get("is_dynamic") self.static_input_scales = not self.input_quant.get("is_dynamic")
if self.static_input_scales and per_channel:
raise ValueError(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.")
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
# WEIGHTS # WEIGHTS
@@ -104,24 +136,39 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES # WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively. if self.weight_qscheme == "per_tensor":
# They will be combined to a single scale after weight loading. # Allocate 2 scales for w1 and w3 respectively.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, # They are combined to a single scale after weight loading.
2, w13_weight_scale = torch.nn.Parameter(torch.ones(
dtype=torch.float32), num_experts, 2, dtype=torch.float32),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, num_experts, dtype=torch.float32),
dtype=torch.float32), requires_grad=False)
requires_grad=False) layer.register_parameter("w2_weight_scale", w2_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-TENSOR quantization for FusedMoE.weight_loader.
# Add the quantization method used (per tensor/grouped/channel) extra_weight_attrs.update(
# to ensure the weight scales are loaded in properly {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
extra_weight_attrs.update( set_weight_attrs(w13_weight_scale, extra_weight_attrs)
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) set_weight_attrs(w2_weight_scale, extra_weight_attrs)
set_weight_attrs(w13_weight_scale, extra_weight_attrs) elif self.weight_qscheme == "per_channel":
set_weight_attrs(w2_weight_scale, extra_weight_attrs) # quark's scale is 1 dim.
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, hidden_size, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES # INPUT_SCALES
if self.static_input_scales: if self.static_input_scales:
@@ -185,24 +232,60 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
requires_grad=False) requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert. # For per-tensor case, Fp8 moe kernel needs single weight scale
# We take the max then dequant and requant each expert. # for w13 per expert. Use max then dequant and requant each expert.
assert layer.w13_weight_scale is not None if self.weight_qscheme == "per_tensor":
shard_size = layer.intermediate_size_per_partition assert layer.w13_weight_scale is not None
max_w13_scales = layer.w13_weight_scale.max(dim=1).values shard_size = layer.intermediate_size_per_partition
for expert_id in range(layer.local_num_experts): max_w13_scales = layer.w13_weight_scale.max(dim=1).values
start = 0 for expert_id in range(layer.local_num_experts):
for shard_id in range(2): start = 0
dq_weight = per_tensor_dequantize( for shard_id in range(2):
layer.w13_weight[expert_id][start:start + shard_size, :], dq_weight = per_tensor_dequantize(
layer.w13_weight_scale[expert_id][shard_id]) layer.w13_weight[expert_id][start:start +
layer.w13_weight[expert_id][ shard_size, :],
start:start + shard_size, :], _ = ops.scaled_fp8_quant( layer.w13_weight_scale[expert_id][shard_id])
dq_weight, max_w13_scales[expert_id]) layer.w13_weight[expert_id][
start += shard_size start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False) requires_grad=False)
# quark's scale is 1 dim.
elif self.weight_qscheme == "per_channel":
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False)
w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1)
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
requires_grad=False)
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
rocm_aiter_fused_experts, shuffle_weights)
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
self.fused_experts_func = None
else:
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts
def apply( def apply(
self, self,
@@ -233,8 +316,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
@@ -249,22 +330,60 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype) indices_type=self.topk_indices_dtype)
return fused_experts( if self.rocm_aiter_moe_enabled:
x, return self.rocm_aiter_fused_experts_func(
layer.w13_weight, hidden_states=x,
layer.w2_weight, w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_qscheme == "per_channel",
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
expert_map=expert_map)
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
assert self.fused_experts_func is not None
return self.fused_experts_func(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
use_fp8_w8a8=True, activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_qscheme == "per_channel",
global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale)
activation=activation)
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):