[Rocm] [quantization] Fix quark ptpc moe and add test case (#24649)

Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
Co-authored-by: Haoyang Li <haoyang.li@amd.com>

Parent: 0f7acdd73c
Commit: ca2d1925ef
@@ -77,6 +77,31 @@ def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
         assert output
 
 
+@pytest.mark.parametrize('tp', [1])
+def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
+    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"
+    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+
+            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
+            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)
+
+            if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
+                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
+                assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[
+                    1]
+                assert qkv_proj.weight_scale.shape[1] == 1
+
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
+
+
 @pytest.mark.parametrize('tp', [1])
 def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
     model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
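For context, the new test exercises Quark's per-channel-weight, per-token-activation ("PTPC") FP8 scheme. The sketch below illustrates that quantization layout in plain PyTorch; it is not vLLM code, and the helper names, shapes, and tolerances are illustrative assumptions.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quant_per_channel(w: torch.Tensor):
    # w: [out_channels, in_features]; one static scale per output channel.
    scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX  # [out, 1]
    return (w / scale).to(torch.float8_e4m3fn), scale

def quant_per_token(x: torch.Tensor):
    # x: [num_tokens, in_features]; one scale per token, computed at runtime.
    scale = x.abs().amax(dim=1, keepdim=True) / FP8_MAX  # [tokens, 1]
    return (x / scale).to(torch.float8_e4m3fn), scale

w, x = torch.randn(64, 32), torch.randn(8, 32)
wq, w_scale = quant_per_channel(w)
xq, x_scale = quant_per_token(x)
# Real kernels fuse the scale multiplications; here we dequantize explicitly.
y = (xq.float() @ wq.float().t()) * x_scale * w_scale.t()
print("max abs error vs. fp32:", (y - x @ w.t()).abs().max().item())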
@@ -5,17 +5,25 @@ from typing import Any, Callable, Optional, Union
 
 import torch
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_moe_enabled)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    prepare_moe_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     OCP_MX_BLOCK_SIZE)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
 
@@ -67,21 +75,45 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         self.weight_quant = weight_config
         self.input_quant = input_config
 
-        weight_qscheme = self.weight_quant.get("qscheme")
-        input_qscheme = self.input_quant.get("qscheme")
-        if not (weight_qscheme == "per_tensor"
-                and input_qscheme == "per_tensor"):
+        self.weight_qscheme = self.weight_quant.get("qscheme")
+        self.input_qscheme = self.input_quant.get("qscheme")
+        per_tensor = (self.weight_qscheme == "per_tensor"
+                      and self.input_qscheme == "per_tensor")
+        per_channel = (self.weight_qscheme == "per_channel"
+                       and self.input_qscheme == "per_channel")
+        self.act_quant_group_shape = GroupShape.PER_TOKEN \
+            if per_channel else GroupShape.PER_TENSOR
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
-                f"{weight_qscheme}, {input_qscheme}")  # noqa E501
+                "For FP8 Fused MoE layers, only per-tensor and per-channel "
+                "scales for weights and activations are supported. Found "
+                f"{self.weight_qscheme}, {self.input_qscheme}")  # noqa E501
 
         self.static_input_scales = not self.input_quant.get("is_dynamic")
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
+
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        self.use_marlin = (not current_platform.has_device_capability(89)
+                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+        # Disable marlin for rocm
+        if current_platform.is_rocm():
+            self.use_marlin = False
+
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
+        layer.hidden_size = hidden_size
+        layer.num_experts = num_experts
+        layer.orig_dtype = params_dtype
+        layer.weight_block_size = None
         params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
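The constructor change reduces to a small decision table: matching per-tensor or per-channel schemes for weights and activations are accepted, mismatches are rejected, and per-channel weights imply dynamic per-token activation quantization (GroupShape.PER_TOKEN) with static input scales disallowed. A standalone sketch of that selection logic follows; the GroupShape here is an illustrative stand-in, not vLLM's real definition.

from enum import Enum, auto

class GroupShape(Enum):  # stand-in for vLLM's GroupShape, illustration only
    PER_TENSOR = auto()
    PER_TOKEN = auto()

def select_act_group_shape(weight_qscheme: str, input_qscheme: str,
                           static_input_scales: bool) -> GroupShape:
    per_tensor = (weight_qscheme == "per_tensor"
                  and input_qscheme == "per_tensor")
    per_channel = (weight_qscheme == "per_channel"
                   and input_qscheme == "per_channel")
    if not (per_tensor or per_channel):
        raise ValueError(f"Unsupported schemes: {weight_qscheme}, "
                         f"{input_qscheme}")
    if static_input_scales and per_channel:
        # Per-channel weights are only paired with dynamic per-token
        # activation quantization.
        raise ValueError("Static input scales require per-tensor quantization.")
    return GroupShape.PER_TOKEN if per_channel else GroupShape.PER_TENSOR

assert select_act_group_shape("per_channel", "per_channel",
                              False) is GroupShape.PER_TOKEN
assert select_act_group_shape("per_tensor", "per_tensor",
                              True) is GroupShape.PER_TENSOR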
@@ -104,24 +136,39 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        if self.weight_qscheme == "per_tensor":
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        elif self.weight_qscheme == "per_channel":
+            # quark's scale is 1 dim.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
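The two branches above differ only in the scale shapes they allocate. A shape-only sketch, where E, N, and K stand in for num_experts, intermediate_size_per_partition, and hidden_size (variable names here are illustrative):

import torch

E, N, K = 4, 128, 256  # num_experts, intermediate size per partition, hidden size
# per_tensor: one scale per (expert, gate/up shard) for w13 and one per expert
# for w2; the two w13 entries are merged into a single scale after loading.
w13_scale_per_tensor = torch.ones(E, 2, dtype=torch.float32)
w2_scale_per_tensor = torch.ones(E, dtype=torch.float32)
# per_channel: Quark ships 1-D per-output-channel scales, so w13 carries one
# scale per row of the concatenated gate/up projection and w2 one per hidden
# dimension.
w13_scale_per_channel = torch.ones(E, 2 * N, dtype=torch.float32)
w2_scale_per_channel = torch.ones(E, K, dtype=torch.float32)
print(w13_scale_per_channel.shape, w2_scale_per_channel.shape)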
@@ -185,24 +232,60 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                       requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For per-tensor case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_qscheme == "per_tensor":
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
+        # quark's scale is 1 dim.
+        elif self.weight_qscheme == "per_channel":
+            if self.act_quant_group_shape == GroupShape.PER_TOKEN:
+                w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False)
+                w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1)
+                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
+                                                           requires_grad=False)
+        # Property to determine if AITER is used
+        if self.rocm_aiter_moe_enabled:
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
+                rocm_aiter_fused_experts, shuffle_weights)
+
+            # reshaping weights is required for aiter moe kernel.
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)
+
+            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                 requires_grad=False)
+
+            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
+        elif self.use_marlin:
+
+            prepare_moe_fp8_layer_for_marlin(layer, False)
+            # Activations not quantized for marlin.
+            del layer.w13_input_scale
+            del layer.w2_input_scale
+            self.fused_experts_func = None
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+            self.fused_experts_func = fused_experts
 
     def apply(
         self,
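Two of the post-loading fix-ups above deserve a closer look: in the per-tensor branch each expert's gate (w1) and up (w3) shards were loaded with separate scales, so both shards are dequantized and requantized against the shared maximum; in the per-channel branch the 1-D Quark scales only need a trailing unit dimension so they broadcast over the weights. Below is a plain-PyTorch sketch of the requantization step, with ops.scaled_fp8_quant replaced by explicit arithmetic; names and shapes are illustrative assumptions.

import torch

def requant_w13_with_max_scale(w13: torch.Tensor, scales: torch.Tensor):
    # w13: [E, 2 * shard, K] fp8 weights; scales: [E, 2] per-shard scales.
    num_experts, two_shard, _ = w13.shape
    shard = two_shard // 2
    max_scales = scales.max(dim=1).values  # one shared scale per expert, [E]
    out = w13.clone()
    for e in range(num_experts):
        start = 0
        for s in range(2):
            # Dequantize this shard with its own scale, then requantize it
            # with the expert-wide maximum scale.
            dq = w13[e, start:start + shard, :].float() * scales[e, s]
            out[e, start:start + shard, :] = (dq / max_scales[e]).to(w13.dtype)
            start += shard
    return out, max_scales

w13 = torch.randn(2, 8, 16).to(torch.float8_e4m3fn)
w13_requant, shared_scales = requant_w13_with_max_scale(
    w13, torch.tensor([[0.5, 1.0], [2.0, 0.25]]))
# The per-channel branch, by contrast, only reshapes its 1-D scales
# ([E, 2N] -> [E, 2N, 1]) so they broadcast over the weight's last dimension.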
@@ -233,8 +316,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             raise NotImplementedError(
                 "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -249,22 +330,60 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
 
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            use_fp8_w8a8=True,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            activation=activation)
+        if self.rocm_aiter_moe_enabled:
+            return self.rocm_aiter_fused_experts_func(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_qscheme == "per_channel",
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                expert_map=expert_map)
+        if self.use_marlin:
+            assert activation == "silu", (
+                f"{activation} not supported for Marlin MoE.")
+            return torch.ops.vllm.fused_marlin_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                None,
+                None,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                router_logits,
+                topk_weights,
+                topk_ids,
+                quant_type_id=scalar_types.float8_e4m3fn.id,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map)
+
+        assert self.fused_experts_func is not None
+
+        return self.fused_experts_func(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_fp8_w8a8=True,
+            per_channel_quant=self.weight_qscheme == "per_channel",
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale)
 
 
 class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
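With these changes, apply() picks one of three backends in a fixed order. The sketch below restates that control flow with stand-in callables; the real call sites and arguments (per_channel_quant, scalar_types.float8_e4m3fn, and so on) are the ones shown in the diff above.

from typing import Callable, Optional

def pick_moe_backend(rocm_aiter_moe_enabled: bool, use_marlin: bool,
                     aiter_func: Optional[Callable],
                     marlin_func: Optional[Callable],
                     fused_experts_func: Optional[Callable]) -> Callable:
    # 1. ROCm AITER fused experts, which takes per_channel_quant directly.
    if rocm_aiter_moe_enabled:
        assert aiter_func is not None
        return aiter_func
    # 2. Marlin FP8 MoE (only reachable off ROCm, and only with silu).
    if use_marlin:
        assert marlin_func is not None
        return marlin_func
    # 3. The generic fused_experts kernel bound in
    #    process_weights_after_loading.
    assert fused_experts_func is not None
    return fused_experts_func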