[Feature][Quantization] MXFP4 support for MOE models (#17888)

Signed-off-by: Felix Marty <felmarty@amd.com>
Signed-off-by: Bowen Bao <bowenbao@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Co-authored-by: Bowen Bao <bowenbao@amd.com>
fxmarty-amd 2025-07-09 22:19:02 +02:00 committed by GitHub
parent bf03ff3575
commit 332d4cb17b
15 changed files with 873 additions and 104 deletions

View File

@ -229,3 +229,28 @@ python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \
--model_export hf_format \
--tasks gsm8k
```
## Using MXFP4 models
vLLM supports loading MXFP4 models quantized offline with AMD Quark, compliant with the [Open Compute Project (OCP) Microscaling Formats (MX) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
This scheme currently supports only dynamic quantization for activations.
Example usage, after installing the latest AMD Quark release:
```bash
vllm serve fxmarty/qwen_1.5-moe-a2.7b-mxfp4 --tensor-parallel-size 1
```
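Once the server is running, the model can be queried through vLLM's OpenAI-compatible API. Below is a minimal sketch using the `openai` Python client; the endpoint and prompt are placeholders, not part of this change:
```python
from openai import OpenAI

# Assumes the `vllm serve` command above is running on the default port.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="fxmarty/qwen_1.5-moe-a2.7b-mxfp4",
    prompt="Today I am in the French Alps and",
    max_tokens=20,
)
print(completion.choices[0].text)
```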
On devices that do not support MXFP4 operations natively (e.g. AMD Instinct MI325, MI300 and MI250), the MXFP4 matrix multiplications can be emulated: weights are dequantized from MXFP4 to half precision on the fly using a fused kernel, and activations go through a quantize-dequantize (QDQ) round trip, as sketched below. This is useful, for example, to evaluate MXFP4 models with vLLM, or to benefit from the roughly 4x weight memory savings compared to float16 and bfloat16.
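Conceptually, the emulation path reduces to the following sketch, built on the `dequant_mxfp4` and `quant_dequant_mxfp4` helpers added in `vllm/model_executor/layers/quantization/utils/mxfp4_utils.py` by this change (simplified; the real path goes through the fused MoE and linear kernels):
```python
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    dequant_mxfp4, quant_dequant_mxfp4)


def mxfp4_emulated_linear(x: torch.Tensor, w_packed: torch.Tensor,
                          w_scale: torch.Tensor) -> torch.Tensor:
    # Unpack the FP4 weights (two values per uint8) and apply the E8M0
    # per-group scales, producing a half-precision weight tensor.
    w = dequant_mxfp4(w_packed, w_scale, x.dtype)
    # Quantize-dequantize (QDQ) the activations to model MXFP4 rounding.
    x = quant_dequant_mxfp4(x)
    # The matrix multiplication itself runs in half precision.
    return F.linear(x, w)
```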
To quantize models offline to the MXFP4 data type, the easiest approach is to use AMD Quark's [quantization script](https://quark.docs.amd.com/latest/pytorch/example_quark_torch_llm_ptq.html), for example:
```bash
python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
--quant_scheme w_mxfp4_a_mxfp4_sym \
--output_dir qwen_1.5-moe-a2.7b-mxfp4 \
--skip_evaluation \
--model_export hf_format \
--group_size 32
```

View File

@ -174,6 +174,7 @@ def test_fused_moe(
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False,
block_shape=None)

View File

@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import importlib.metadata
from dataclasses import dataclass
import pytest
import torch
from packaging import version
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
@dataclass
class ModelCase:
model_id: str
tp: int
@pytest.mark.parametrize('model_case', [
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1)
])
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
reason="amd-quark>=0.9 is not available")
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
if torch.cuda.device_count() < model_case.tp:
pytest.skip(f"This test requires >={model_case.tp} gpus, got only "
f"{torch.cuda.device_count()}")
with vllm_runner(model_case.model_id,
tensor_parallel_size=model_case.tp,
load_format="dummy") as llm:
# TODO: llm.apply_model(check_model) currently relies on V0 internals.
# Re-enable this later.
# def check_model(model):
# layer = model.model.layers[0]
# qkv_proj = layer.self_attn.qkv_proj
# assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
# assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
# assert isinstance(layer.mlp.experts.quant_method,
# QuarkW4A4MXFp4MoEMethod)
# if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
# llm.apply_model(check_model)
output = llm.generate_greedy("Today I am in the French Alps and",
max_tokens=20)
assert output

View File

@ -0,0 +1,287 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
BFLOAT16_EXP_BIAS = 127
BFLOAT16_MANTISSA_BITS = 7
BFLOAT16_EXP_BITS = 8
FLOAT16_EXP_BIAS = 15
FLOAT16_MANTISSA_BITS = 10
FLOAT16_EXP_BITS = 5
FLOAT8_E8M0_MAX_EXP = 127
FLOAT4_EXP_BIAS = 1
FLOAT4_MANTISSA_BITS = 1
FLOAT16_VAL_TO_ADD = (1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1))
FLOAT16_SIGN_EXPONENT_MASK = ((
(1 << (FLOAT16_EXP_BITS + 1)) - 1) << FLOAT16_MANTISSA_BITS)
BFLOAT16_VAL_TO_ADD = (1 <<
(BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1))
BFLOAT16_SIGN_EXPONENT_MASK = ((
(1 << (BFLOAT16_EXP_BITS + 1)) - 1) << BFLOAT16_MANTISSA_BITS)
def e8m0_to_half(scale, half_dtype: torch.dtype):
assert scale.dtype == torch.uint8
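    # E8M0 scales store only a biased 8-bit exponent (bias 127, no sign, no
    # mantissa): the decoded scale is 2**(byte - 127), e.g. 127 -> 1.0, 126 -> 0.5.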
scale_exp = scale.to(torch.int16) - 127
# This can be implemented with bitwise operations in a proper kernel.
scale_half = 2.0**(scale_exp.to(torch.float))
return scale_half.to(half_dtype)
def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype,
half_exp_bias: int, half_mantissa_bits: int):
assert val.dtype == torch.uint8
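    # Each uint8 packs two FP4 E2M1 values (1 sign bit, 2 exponent bits with
    # bias 1, 1 mantissa bit); the low nibble is the first value of the pair.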
unpacked = torch.zeros(*val.shape[:-1],
val.shape[-1] * 2,
dtype=torch.uint8,
device=val.device)
unpacked[..., 1::2] = (val >> 4) & 0x0F # Extract high 4 bits.
unpacked[..., ::2] = val & 0x0F # Extract low 4 bits.
    # Takes a float4 value represented as b0000xxxx
    # and converts it to the corresponding float16/bfloat16 value.
sign = unpacked >> 3
exp = (unpacked >> 1) & 3
new_mantissa = unpacked & 1
# if exp == 0 and new_mantissa == 0:
# new_exp = 0
# else:
# new_exp = exp - FLOAT4_EXP_BIAS + FLOAT16_EXP_BIAS
# int8_t works with float16, but may overflow with bfloat16.
new_exp = exp - FLOAT4_EXP_BIAS + half_exp_bias
# Cast b0000 to 0. in fp16/bf16.
new_exp = new_exp * torch.logical_or(exp > 0, new_mantissa > 0)
# Cast b0001 to 0.5 in fp16/bf16.
new_mantissa = torch.logical_and(new_mantissa, exp > 0)
new_mantissa = new_mantissa.to(torch.int32)
new_exp = new_exp.to(torch.int32)
sign = sign.to(torch.int32)
qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + (
new_mantissa << (half_mantissa_bits - 1))
assert qdq_val.max() <= 65535
assert qdq_val.min() >= 0
qdq_val = qdq_val.to(torch.uint16)
result = qdq_val.view(float_dtype)
return result
def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor,
float_dtype: torch.dtype) -> torch.Tensor:
assert x.dtype == torch.uint8
assert scale.dtype == torch.uint8
if float_dtype == torch.float16:
half_exp_bias = FLOAT16_EXP_BIAS
half_mantissa_bits = FLOAT16_MANTISSA_BITS
elif float_dtype == torch.bfloat16:
half_exp_bias = BFLOAT16_EXP_BIAS
half_mantissa_bits = BFLOAT16_MANTISSA_BITS
scale_half = e8m0_to_half(scale, half_dtype=float_dtype)
x_half = upcast_fp4_to_fp16_or_bf16(x,
float_dtype=float_dtype,
half_exp_bias=half_exp_bias,
half_mantissa_bits=half_mantissa_bits)
x_half = x_half.reshape(*x_half.shape[:-1], -1, 32)
x_half = x_half * scale_half[..., None]
x_half = x_half.reshape(*x_half.shape[:-2], -1)
return x_half
def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int,
half_exp_bias: int):
# Casts an fp16/bf16 input to the restricted values of float4_e2m1,
# that is to say [0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0,
# -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0].
float_type = val.dtype
# "rshift_cuda" not implemented for 'UInt16'
val_view = val.view(torch.int16) #.to(torch.int32)
exp = val_view >> half_mantissa_bits
exp = exp & ((1 << half_exp_bits) - 1)
exp = exp.view(torch.uint16).to(torch.int32)
sign = (val_view >> (half_mantissa_bits + half_exp_bits)) & 1
mantissa_last = (val_view >> (half_mantissa_bits - 1)) & 1
exp_unbias = exp - half_exp_bias
new_exp = exp_unbias + FLOAT4_EXP_BIAS
exp_shift = (new_exp <= 0) * (1 - new_exp)
# Typically 9.
# Take the min to prevent overflow on `uint16_t half`. This is the case for
# very small values, correctly mapped to `round_close`.
tail_bits = half_mantissa_bits - FLOAT4_MANTISSA_BITS + exp_shift
tail_bits[tail_bits >= 16] = 16
mantissa_plus_one = val_view & ((1 << (half_mantissa_bits + 1)) - 1)
half = 1 << (tail_bits - 1)
tail = mantissa_plus_one & ((1 << tail_bits) - 1)
round_close = (tail < half) # round towards 0
round_away = (tail > half) # round away from 0
tie = tail == half
new_mantissa_close = torch.zeros(val.shape,
device=val.device,
dtype=torch.bool)
new_exp_close = torch.zeros(val.shape,
device=val.device,
dtype=torch.uint16)
new_mantissa_away = torch.zeros(val.shape,
device=val.device,
dtype=torch.bool)
new_exp_away = torch.zeros(val.shape,
device=val.device,
dtype=torch.uint16)
new_exp_tie = torch.zeros(val.shape, device=val.device, dtype=torch.uint16)
# 1. round down
# if new_exp == 0: # case [0.5, 0.749999]
# new_mantissa = 0
# elif new_exp < 0: # case [0, 0.24999]
# new_mantissa = 0
# else:
# new_mantissa = mantissa_last
new_mantissa_close = (new_exp > 0) * mantissa_last
new_exp_close = exp
# # 2. round up
# if new_exp <= 0: # case [0.250001, 0.499999] and [0.75001, 0.99999]
# new_mantissa = 0
# new_exp += 1
# elif mantissa_last == 0:
# new_mantissa = 1
# else:
# new_mantissa = 0
# new_exp += 1
new_mantissa_away = torch.logical_and(new_exp > 0, mantissa_last == 0)
new_exp_away = exp + torch.logical_or(new_exp <= 0, mantissa_last == 1)
# # 3. tie
# 0.25 -> 0. (handled by `exp > (half_exp_bias - 2)`)
# 0.75 -> 1.
# 1.25 -> 1.
# 1.75 -> 2.
# 2.5 -> 2.
# 3.5 -> 4.
# 5. -> 4.
new_exp_tie = (exp > (half_exp_bias - 2)) * (exp + (mantissa_last == 1))
# Gather round up, round down and tie.
new_exp = round_away * new_exp_away \
+ round_close * new_exp_close \
+ tie * new_exp_tie
new_mantissa = round_away * new_mantissa_away \
+ round_close * new_mantissa_close
# if new_exp > 3:
# new_mantissa = 1
new_mantissa = new_mantissa + (new_exp >
(2 + half_exp_bias)) * (new_mantissa == 0)
# Clamp the exponent to acceptable values.
new_exp = (new_exp >= (half_exp_bias - 2)) * torch.clamp(
new_exp, half_exp_bias - 2, half_exp_bias + 2)
sign = sign.to(torch.int32)
new_mantissa = new_mantissa.to(torch.int32)
qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + (
new_mantissa << (half_mantissa_bits - 1))
assert qdq_val.max() <= 65535
assert qdq_val.min() >= 0
assert qdq_val.dtype == torch.int32
qdq_val = qdq_val.to(torch.uint16)
result = qdq_val.view(float_type)
return result
def qdq_mxfp4_torch(x: torch.Tensor,
scale_calculation_mode: str = "even") -> torch.Tensor:
half_dtype = x.dtype
if half_dtype == torch.float16:
half_mantissa_bits = FLOAT16_MANTISSA_BITS
half_exp_bits = FLOAT16_EXP_BITS
half_exp_bias = FLOAT16_EXP_BIAS
val_to_add = FLOAT16_VAL_TO_ADD
sign_exponent_mask = FLOAT16_SIGN_EXPONENT_MASK
elif half_dtype == torch.bfloat16:
half_mantissa_bits = BFLOAT16_MANTISSA_BITS
half_exp_bits = BFLOAT16_EXP_BITS
half_exp_bias = BFLOAT16_EXP_BIAS
val_to_add = BFLOAT16_VAL_TO_ADD
sign_exponent_mask = BFLOAT16_SIGN_EXPONENT_MASK
else:
raise ValueError("not implemented")
x = x.reshape(*x.shape[:-1], -1, 32)
block_max = torch.max(torch.abs(x), dim=-1).values
block_max = block_max.view(torch.uint16).to(torch.int32)
block_max_uint = torch.bitwise_and(block_max + val_to_add,
sign_exponent_mask)
assert block_max_uint.max() <= 65535
assert block_max_uint.min() >= 0
assert block_max_uint.dtype == torch.int32
block_max_uint = block_max_uint.to(torch.uint16)
block_max = block_max_uint.view(half_dtype)
scale_exp = FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to(
torch.int32) - 2
scale_exp = torch.clamp(scale_exp, 0, 2 * FLOAT8_E8M0_MAX_EXP)
scale = 2.0**(scale_exp - FLOAT8_E8M0_MAX_EXP)
scale = scale.to(half_dtype)
x = x / scale[..., None]
x_fp4 = fp16_to_fp4_simulate(x,
half_exp_bits=half_exp_bits,
half_mantissa_bits=half_mantissa_bits,
half_exp_bias=half_exp_bias)
x_fp4 = x_fp4 * scale[..., None]
return x_fp4.reshape(*x_fp4.shape[:-2], -1)

View File

@ -3,15 +3,44 @@
"""Test model set-up and weight loading for quark-quantized models.
Run `pytest tests/quantization/test_quark.py`.
See also `tests/kernels/moe/test_mxfp4_moe.py`.
"""
import importlib
import importlib.metadata
import os
from dataclasses import dataclass
import huggingface_hub
import lm_eval
import pytest
import torch
from packaging import version
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.platforms import current_platform
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
if QUARK_MXFP4_AVAILABLE:
from quark.torch.export.nn.modules.realquantizer import (
StaticScaledRealQuantizer)
from quark.torch.kernel import mx as mx_kernel
from quark.torch.quantization.config.config import FP4PerGroupSpec
try:
huggingface_hub.list_repo_refs(
"amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ")
HF_HUB_AMD_ORG_ACCESS = True
except huggingface_hub.errors.RepositoryNotFoundError:
HF_HUB_AMD_ORG_ACCESS = False
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
@ -90,3 +119,145 @@ def test_quark_fp8_parity(vllm_runner):
for key in fp8_state_dict:
assert torch.equal(fp8_state_dict[key], quark_state_dict[key])
@dataclass
class ModelCase:
model_id: str
tp: int
@dataclass
class GSM8KAccuracyTestConfig:
model_name: str
    expected_value: float
def get_model_args(self) -> str:
return (
f"pretrained={self.model_name},"
"dtype=auto,add_bos_token=True,tensor_parallel_size=8,gpu_memory_utilization=0.7,max_model_len=38768"
)
ACCURACY_CONFIGS = [
# Private model.
GSM8KAccuracyTestConfig(
model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
        expected_value=0.96),
]
@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
not HF_HUB_AMD_ORG_ACCESS,
reason="Read access to huggingface.co/amd is required for this test.")
def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
if torch.cuda.device_count() < 8:
pytest.skip(
f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
)
task = "gsm8k"
rtol = 0.03
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(),
tasks=task,
batch_size=64,
num_fewshot=8,
)
    EXPECTED_VALUE = config.expected_value
measured_value = results["results"][task]["exact_match,strict-match"]
assert (measured_value - rtol < EXPECTED_VALUE
and measured_value + rtol > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
del os.environ["VLLM_USE_TRITON_FLASH_ATTN"]
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings",
[[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype,
                                     scalings: list[float]):
torch.manual_seed(0)
hidden_size = 64 * 32
inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") -
0.5) * 2
for i in range(hidden_size // 32):
inp[:, i * 32:(i + 1) *
32] = inp[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)]
inp_kernel = inp.clone()
inp_kernel_clone = inp_kernel.clone()
res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even")
res_torch = qdq_mxfp4_torch(inp_kernel, "even")
for i in range(hidden_size // 32):
assert torch.all(torch.isfinite(res_hip[:, i * 32:(i + 1) * 32]))
assert torch.all(torch.isfinite(res_torch[:, i * 32:(i + 1) * 32]))
torch.testing.assert_close(res_hip[:, i * 32:(i + 1) * 32],
res_torch[:, i * 32:(i + 1) * 32])
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings",
[[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(float_dtype: torch.dtype,
                                          scalings: list[float]):
qspec = FP4PerGroupSpec(
ch_axis=-1,
group_size=32,
scale_format="e8m0",
scale_calculation_mode="even",
is_dynamic=False,
).to_quantization_spec()
weight_quantizer = StaticScaledRealQuantizer(
qspec=qspec,
quantizer=None,
reorder=False,
real_quantized=True,
float_dtype=float_dtype,
device="cuda",
)
observer = qspec.observer_cls(qspec, device="cuda")
hidden_size = 512
shape = (11008, hidden_size)
w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2
# Make it so that different groups have different scales.
for i in range(hidden_size // 32):
w[:, i * 32:(i + 1) *
32] = w[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)]
observer(w)
scale, _ = observer._calculate_qparams()
weight_quantizer.scale = scale
w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda")
weight_quantizer.maybe_convert_and_transpose_scale()
scale = weight_quantizer.scale
out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype)
out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype)
assert torch.equal(out_hip, out_torch)

View File

@ -94,7 +94,6 @@ if TYPE_CHECKING:
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_QUARK_EMU_MEM_OPT: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
@ -723,14 +722,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: maybe_convert_int(
os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)),
# If set, when running in Quark emulation mode, do not dequantize the
# weights at load time. Instead, dequantize weights on-the-fly during
# kernel execution.
# This allows running larger models at the cost of slower inference.
# This flag has no effect when not running in Quark emulation mode.
"VLLM_QUARK_EMU_MEM_OPT":
lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))),
# Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT":
lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),

View File

@ -50,11 +50,14 @@ def get_config_quant_dtype(
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
) -> Optional[torch.dtype]:
use_mxfp4_w4a4: bool,
) -> Union[None, torch.dtype, str]:
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
elif use_mxfp4_w4a4:
return "mxfp4"
return None
@ -126,6 +129,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: Optional[list[int]] = None,
@ -144,6 +148,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
)
return FusedMoEQuantConfig(
quant_dtype,

View File

@ -632,6 +632,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
):
@ -641,12 +642,14 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
assert not use_mxfp4_w4a4, "NYI"
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@ -838,6 +841,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
):
@ -847,18 +851,21 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
assert not use_mxfp4_w4a4, "NYI"
assert max_num_tokens > 0
assert num_dispatchers > 0
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int8_w8a8 = use_int8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@ -941,6 +948,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(

View File

@ -27,6 +27,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
@ -973,13 +975,16 @@ def get_config_dtype_str(
dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False) -> Optional[str]:
use_fp8_w8a8: Optional[bool] = False,
use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]:
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a16:
return "int8_w8a16"
elif use_int4_w4a16:
return "int4_w4a16"
elif use_mxfp4_w4a4:
return "mxfp4_w4a4"
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
@ -998,6 +1003,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1011,9 +1017,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
use_mxfp4_w4a4, per_channel_quant, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape)
def inplace_fused_experts_fake(
@ -1028,6 +1034,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1062,6 +1069,7 @@ def outplace_fused_experts(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1075,10 +1083,10 @@ def outplace_fused_experts(
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, per_channel_quant,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
use_int4_w4a16, use_mxfp4_w4a4,
per_channel_quant, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape)
def outplace_fused_experts_fake(
@ -1092,6 +1100,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1145,6 +1154,7 @@ def fused_experts(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1203,6 +1213,7 @@ def fused_experts(
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
@ -1228,6 +1239,7 @@ def fused_experts_impl(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1243,6 +1255,9 @@ def fused_experts_impl(
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), (
"Hidden size mismatch")
elif use_mxfp4_w4a4:
# 16bit activation and fp4x2 packed weight
assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch"
else:
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")
@ -1268,12 +1283,14 @@ def fused_experts_impl(
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
dtype=hidden_states.dtype)
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4)
get_config_func = functools.partial(
try_get_optimal_moe_config,
@ -1313,6 +1330,13 @@ def fused_experts_impl(
else:
out_hidden_states = torch.empty_like(hidden_states)
if use_mxfp4_w4a4:
# Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
@ -1429,6 +1453,7 @@ def fused_moe(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1470,6 +1495,9 @@ def fused_moe(
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and
OCP MXFP4 activation to compute the inner products for w1 and w2.
Defaults to False.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
@ -1513,6 +1541,7 @@ def fused_moe(
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
@ -1533,6 +1562,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
):
@ -1542,6 +1572,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
@ -1550,6 +1581,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
@property
def activation_formats(
@ -1627,6 +1659,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
@ -1718,6 +1751,7 @@ def modular_triton_fused_moe(
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
use_mxfp4_w4a4: bool,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> mk.FusedMoEModularKernel:
@ -1728,6 +1762,7 @@ def modular_triton_fused_moe(
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
),

View File

@ -19,6 +19,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False,
@ -29,6 +30,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
@ -37,6 +39,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import prod
from typing import Optional
from typing import Optional, Union
import torch
@ -10,6 +10,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
quant_dequant_mxfp4)
from vllm.platforms import current_platform
from vllm.utils import cdiv
@ -74,10 +77,25 @@ def _int8_quantize(
return A, A_scale
def _mxfp4_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, None]:
assert block_shape is None
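    # Without native MX support, emulate MXFP4 activation quantization with a
    # quantize-dequantize (QDQ) round trip; a native path is not wired up yet.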
if not current_platform.supports_mx():
A = quant_dequant_mxfp4(A)
else:
raise NotImplementedError()
return A, None
def moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
quant_dtype: Optional[torch.dtype],
quant_dtype: Union[None, torch.dtype, str],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -85,6 +103,8 @@ def moe_kernel_quantize_input(
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp4":
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
else:
return A, A_scale

View File

@ -237,12 +237,6 @@ class QuarkConfig(QuantizationConfig):
"Quark model is not in MX-FP4 format: not group_size=32")
return False
# Weights need to use static quantization.
if weight_quant.get("is_dynamic") is True:
logger.debug(
"Quark model is not in MX-FP4 format: not weight static")
return False
# Activations need to use dynamic quantization.
if input_quant.get("is_dynamic") is False:
logger.debug(

View File

@ -5,11 +5,12 @@ from typing import Any, Callable, Optional
import torch
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
@ -17,7 +18,9 @@ from vllm.platforms import current_platform
logger = init_logger(__name__)
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"]
__all__ = [
"QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod"
]
class QuarkMoEMethod(FusedMoEMethodBase):
@ -40,6 +43,8 @@ class QuarkMoEMethod(FusedMoEMethodBase):
if quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
elif quant_config._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
@ -242,4 +247,163 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
a2_scale=layer.w2_input_scale,
activation=activation)
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
Any]):
self.weight_quant = weight_config
self.input_quant = input_config
weight_qscheme = self.weight_quant.get("qscheme")
input_qscheme = self.input_quant.get("qscheme")
if not (weight_qscheme == "per_group"
and input_qscheme == "per_group"):
raise ValueError(
"For MX(FP4) Fused MoE layers, only per-group scales "
"for weights and activations are supported. Found "
f"{weight_qscheme}, {input_qscheme}") # noqa E501
self.static_input_scales = not self.input_quant.get("is_dynamic")
if self.static_input_scales:
raise NotImplementedError(
"QuarkW4A4MXFp4MoEMethod with static input scales is currently "
"not implemented. Please open an issue.")
if not current_platform.supports_mx():
self.emulate = True
logger.warning_once(
"The current platform does not support native MXFP4 "
"computation. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
else:
self.emulate = True
logger.warning_once(
"The current platform supports native MXFP4 "
"computation, but kernels are not yet integrated in vLLM. "
"Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
params_dtype = torch.uint8
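        # MXFP4 packs two FP4 values per uint8 along the input dimension
        # (hence the `// 2` below), with one E8M0 scale byte per group of
        # OCP_MX_BLOCK_SIZE (32) input elements.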
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // 2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // 2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // OCP_MX_BLOCK_SIZE,
dtype=params_dtype,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
hidden_size,
intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
out = fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_mxfp4_w4a4=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None,
block_shape=None,
activation=activation,
)
return out

View File

@ -6,14 +6,16 @@ from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.platforms import current_platform
logger = init_logger(__name__)
__all__ = ["QuarkW4A4MXFP4"]
@ -25,7 +27,29 @@ class QuarkW4A4MXFP4(QuarkScheme):
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.emulate = not current_platform.supports_mx()
self.static_input_scales = not input_quant_spec.get("is_dynamic")
if self.static_input_scales:
raise NotImplementedError(
"QuarkW4A4MXFP4 with static input scales is currently not "
"implemented. Please open an issue.")
if not current_platform.supports_mx():
self.emulate = True
logger.warning_once(
"The current platform does not support native MXFP4 "
"computation. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
else:
self.emulate = True
logger.warning_once(
"The current platform supports native MXFP4 "
"computation, but kernels are not yet integrated in vLLM. "
"Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
@classmethod
def get_min_capability(cls) -> int:
@ -37,43 +61,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.emulate:
try:
from quark.torch.export.nn.modules import realquantizer
from quark.torch.quantization.config.config import (
QuantizationSpec)
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use AMD Quark "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
weight_quant_spec = QuantizationSpec.from_dict(
self.weight_quant_spec)
weight_quantizer = realquantizer.get_real_quantizer(
qspec=weight_quant_spec,
quantizer=None,
real_quantized=True,
reorder=False,
float_dtype=self.out_dtype,
scale_shape=layer.weight_scale.shape,
zero_point_shape=None,
)
weight_quantizer.scale.data = layer.weight_scale.data
if not envs.VLLM_QUARK_EMU_MEM_OPT:
layer.weight = torch.nn.Parameter(
weight_quantizer(layer.weight.data).to(self.out_dtype),
requires_grad=False,
)
else:
self.weight_quantizer = weight_quantizer
layer.weight_scale = None
# This call is necessary to release the scales memory.
torch.cuda.empty_cache()
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
@ -116,11 +103,10 @@ class QuarkW4A4MXFP4(QuarkScheme):
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.emulate:
if envs.VLLM_QUARK_EMU_MEM_OPT:
dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
else:
dq_w = layer.weight
qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
return F.linear(qdq_x, dq_w, bias)
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
x = quant_dequant_mxfp4(x)
return F.linear(x, dq_w, bias)
else:
raise NotImplementedError()

View File

@ -1,45 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.utils import direct_register_custom_op
OCP_MX_BLOCK_SIZE = 32
def per_token_group_quant_mxfp4(x: torch.Tensor,
block_k: int,
scale_calculation_mode: str = "even"
) -> tuple[torch.Tensor, torch.Tensor]:
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
float_dtype: torch.dtype) -> torch.Tensor:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
fake_quantize_fp4_fp6_per_group_with_scale)
from quark.torch.quantization.utils import (even_round,
reshape_to_blocks)
from quark.torch.kernel import mx
except ImportError as err:
raise ImportError("The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
axis = -1
block_x = reshape_to_blocks(x, block_k, axis)
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
amax = amax.squeeze(-1)
return mx.dq_mxfp4(x, scale, float_dtype)
# TODO: there are other rounding strategies supported in quark and in the
# config.json that we do not check for here!
if scale_calculation_mode != "even":
raise NotImplementedError(
f"Scale calculation mode {scale_calculation_mode} is not yet "
"supported in MX-FP4 quantization")
scale = even_round(amax, "fp4")
# Apply dequantize(quantize(x)).
x = fake_quantize_fp4_fp6_per_group_with_scale(
x,
scale.to(x.device),
axis=axis,
group_size=block_k,
quant_dtype="fp4",
def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor,
float_dtype: torch.dtype) -> torch.Tensor:
return torch.empty((*x.shape[:-1], x.shape[-1] * 2),
dtype=float_dtype,
device=x.device)
def _quant_dequant_mxfp4(x: torch.Tensor,
scale_calculation_mode: str = "even") -> torch.Tensor:
try:
from quark.torch.kernel import mx
except ImportError as err:
raise ImportError("The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
return mx.qdq_mxfp4(x, scale_calculation_mode)
def _quant_dequant_mxfp4_fake(x: torch.Tensor,
scale_calculation_mode: str = "even"
) -> torch.Tensor:
return torch.empty_like(x)
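# Register the Quark-backed implementations as vLLM custom ops; the *_fake
# variants above provide shapes/dtypes for fake-tensor tracing (e.g. under
# torch.compile).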
try:
direct_register_custom_op(
op_name="dequant_mxfp4",
op_func=_dequant_mxfp4,
mutates_args=[],
fake_impl=_dequant_mxfp4_fake,
)
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
except AttributeError as error:
raise error
return x, scale
try:
direct_register_custom_op(
op_name="quant_dequant_mxfp4",
op_func=_quant_dequant_mxfp4,
mutates_args=[],
fake_impl=_quant_dequant_mxfp4_fake,
)
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
except AttributeError as error:
raise error