[Bugfix] [AITER] [ROCm] Fix Quark MoE Quant Config and AITER Fused MoE quant type logic (#27029)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
commit e33ee23ee3
parent b10c64c834
@@ -492,6 +492,8 @@ def rocm_aiter_fused_experts(
         assert quant_config.w1_scale is not None
         assert quant_config.w2_scale is not None
         quant_method = QuantMethod.BLOCK_128x128.value
+    elif quant_config.use_fp8_w8a8 and quant_config.per_out_ch_quant:
+        quant_method = QuantMethod.PER_TOKEN.value
     elif quant_config.use_fp8_w8a8:
         # Currently only the per-tensor quantization method is enabled.
         quant_method = QuantMethod.PER_TENSOR.value
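Read top to bottom, the chain gives block-scaled quantization precedence, then per-token (now signalled by per_out_ch_quant), then per-tensor as the FP8 fallback. A minimal sketch of that precedence, assuming simplified stand-ins for QuantMethod and the quant config (placeholder enum values, and a guessed block-shape guard on the first branch; none of these definitions are the real vLLM/AITER ones):

    from dataclasses import dataclass
    from enum import Enum
    from typing import Optional

    class QuantMethod(Enum):
        # Placeholder values; the real AITER enum defines its own.
        NO = 0
        PER_TENSOR = 1
        PER_TOKEN = 2
        BLOCK_128x128 = 3

    @dataclass
    class FakeQuantConfig:
        use_fp8_w8a8: bool = False
        per_out_ch_quant: bool = False
        block_shape: Optional[tuple] = None

    def select_quant_method(cfg: FakeQuantConfig) -> int:
        # Mirrors the if/elif precedence in the hunk above: block-scaled
        # quant first, then per-token, then per-tensor.
        if cfg.use_fp8_w8a8 and cfg.block_shape is not None:
            return QuantMethod.BLOCK_128x128.value
        if cfg.use_fp8_w8a8 and cfg.per_out_ch_quant:
            return QuantMethod.PER_TOKEN.value
        if cfg.use_fp8_w8a8:
            return QuantMethod.PER_TENSOR.value
        return QuantMethod.NO.value

    assert select_quant_method(FakeQuantConfig(use_fp8_w8a8=True)) == 1
    assert select_quant_method(
        FakeQuantConfig(use_fp8_w8a8=True, per_out_ch_quant=True)
    ) == 2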
@@ -342,7 +342,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            per_act_token_quant=self.weight_qscheme == "per_channel",
+            per_act_token_quant=self.input_qscheme == "per_channel",
+            per_out_ch_quant=self.weight_qscheme == "per_channel",
         )

     def apply(
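The one-line swap matters whenever the weight and activation schemes differ. A small self-contained check, using the scheme strings from the diff but a hypothetical scheme combination:

    # Hypothetical combination: per-channel weights, per-tensor activations.
    weight_qscheme = "per_channel"
    input_qscheme = "per_tensor"

    # Before the fix, per-act-token quantization was keyed off the weight
    # scheme, so this combination enabled it incorrectly.
    buggy_per_act_token = weight_qscheme == "per_channel"  # True (wrong)

    # After the fix, each flag reads the scheme it actually describes.
    per_act_token_quant = input_qscheme == "per_channel"  # False (correct)
    per_out_ch_quant = weight_qscheme == "per_channel"  # True

    assert buggy_per_act_token and not per_act_token_quant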
@@ -464,8 +464,16 @@ class Fp8LinearOp:
         else:
             qinput, x_scale = input_2d, input_scale

-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
+        # Must have dim() conditions.
+        # In the per-token quant scenario, when the number of tokens is 1,
+        # the scale will only have 1 element.
+        # Without checking dim(), we cannot distinguish between
+        # per-tensor and per-token quant.
+        # Example:
+        # when the number of tokens is 1, the per-token scale is [[1]],
+        # while a per-tensor scale is [1] or ().
+        per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
+        per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2

         # TODO(luka) do this dispatch during init (after ScaledMM refactor)
         w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
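The ambiguity those new comments describe is easy to reproduce in plain PyTorch, outside vLLM: both single-element layouts below pass a numel() check, and only the added dim() condition tells them apart:

    import torch

    # Per-token scale for a batch with a single token: shape [1, 1].
    per_token_scale = torch.tensor([[0.5]])
    # Per-tensor scale: shape [1] (a 0-d scalar behaves the same way).
    per_tensor_scale = torch.tensor([0.5])

    for scale in (per_token_scale, per_tensor_scale):
        numel_only = scale.numel() == 1  # True for both layouts
        with_dim = scale.numel() == 1 and scale.dim() < 2  # differs
        print(tuple(scale.shape), numel_only, with_dim)

    # Prints:
    # (1, 1) True False   <- per-token, must not take the per-tensor path
    # (1,) True True      <- per-tensor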