[Feature][OCP MX] Support mxfp6 and mixed mxfp6-mxfp4 (#21166)

fxmarty-amd 2025-10-07 15:35:26 +02:00 committed by GitHub
parent 08d26a1b7e
commit 41f1cf38f2
18 changed files with 656 additions and 180 deletions

View File

@ -231,9 +231,9 @@ python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \
--tasks gsm8k
```
## Using MXFP4 models
## Using OCP MX (MXFP4, MXFP6) models
vLLM supports loading MXFP4 models quantized offline through AMD Quark, compliant with [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
vLLM supports loading MXFP4 and MXFP6 models quantized offline through AMD Quark, compliant with [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
These schemes currently support only dynamic quantization for activations.
@ -241,17 +241,21 @@ Example usage, after installing the latest AMD Quark release:
```bash
vllm serve fxmarty/qwen_1.5-moe-a2.7b-mxfp4 --tensor-parallel-size 1
# or, for a model using fp6 activations and fp4 weights:
vllm serve fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3 --tensor-parallel-size 1
```
A simulation of the matrix multiplication execution in MXFP4 can be run on devices that do not support MXFP4 operations natively (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing weights from MXFP4 to half precision on the fly, using a fused kernel. This is useful e.g. to evaluate MXFP4 models using vLLM, or alternatively to benefit from the ~4x memory savings (compared to float16 and bfloat16).
A simulation of the matrix multiplication execution in MXFP4/MXFP6 can be run on devices that do not natively support OCP MX operations (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing the weights from FP4/FP6 to half precision on the fly with a fused kernel. This is useful, for example, to evaluate FP4/FP6 models with vLLM, or to benefit from the ~2.5-4x memory savings compared to float16 and bfloat16.
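Conceptually, the emulated path computes each linear layer roughly as in the sketch below; `emulated_ocp_mx_linear`, `dequant`, and `quant_dequant` are illustrative names standing in for the format-specific helpers (e.g. `dequant_mxfp4`/`dequant_mxfp6` and their quantize-dequantize counterparts), not an actual API:

```python
import torch
import torch.nn.functional as F

def emulated_ocp_mx_linear(x, packed_w, w_scale, dequant, quant_dequant, bias=None):
    # Dequantize the packed FP4/FP6 weight (with its e8m0 group scales) to the
    # activation dtype, then fake-quantize the activation so its numerics match
    # the MX format, and run the GEMM in half precision.
    w = dequant(packed_w, w_scale, x.dtype)
    x_qdq = quant_dequant(x)
    return F.linear(x_qdq, w, bias)
```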
To generate models quantized offline with the MXFP4 data type, the easiest approach is to use AMD Quark's [quantization script](https://quark.docs.amd.com/latest/pytorch/example_quark_torch_llm_ptq.html), for example:
```bash
python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
--quant_scheme w_mxfp4_a_mxfp4_sym \
--quant_scheme w_mxfp4_a_mxfp4 \
--output_dir qwen_1.5-moe-a2.7b-mxfp4 \
--skip_evaluation \
--model_export hf_format \
--group_size 32
```
The current integration supports [all combinations of FP4, FP6_E3M2 and FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) for either weights or activations. Some target hardware, such as AMD Instinct MI350/MI355, supports mixed-precision GEMM, for example using FP6 for activations and FP4 for weights; the supported pairings are sketched below.
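These pairings map onto the `OCP_MX_Scheme` helper added in `ocp_mx_utils.py` by this commit (a minimal sketch; the dtype strings follow the convention used elsewhere in the diff):

```python
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme

# Arguments are (activation dtype, weight dtype).
assert OCP_MX_Scheme.from_quant_dtype("mxfp4", "mxfp4").value == "w_mxfp4_a_mxfp4"
assert OCP_MX_Scheme.from_quant_dtype("mxfp6_e2m3", "mxfp4").value == "w_mxfp4_a_mxfp6_e2m3"
assert OCP_MX_Scheme.from_quant_dtype("mxfp6_e3m2", "mxfp6_e3m2").value == "w_mxfp6_e3m2_a_mxfp6_e3m2"
# Non-MX dtypes (or unsupported pairings) return None.
assert OCP_MX_Scheme.from_quant_dtype("bfloat16", "mxfp4") is None
```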

View File

@ -10,13 +10,6 @@ import pytest
import torch
from packaging import version
from vllm.model_executor.layers.quantization.quark.quark import (
QuarkLinearMethod,
QuarkW4A4MXFP4,
)
from vllm.model_executor.layers.quantization.quark.quark_moe import (
QuarkW4A4MXFp4MoEMethod,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
@ -63,9 +56,11 @@ def enable_pickle(monkeypatch):
@pytest.mark.parametrize(
"model_case",
[
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=2),
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=1),
ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=4),
],
)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@ -76,22 +71,33 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
f"{torch.cuda.device_count()}"
)
# `cuda_graph_sizes=[16]` to reduce load time.
with vllm_runner(
model_case.model_id, tensor_parallel_size=model_case.tp, load_format="dummy"
model_case.model_id,
tensor_parallel_size=model_case.tp,
load_format="dummy",
cuda_graph_sizes=[16],
) as llm:
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
# def check_model(model):
# from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
# QuarkLinearMethod)
# from vllm.model_executor.layers.quantization.quark.schemes.quark_ocp_mx import QuarkOCP_MX # noqa: E501
# from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
# QuarkOCP_MX_MoEMethod)
def check_model(model):
layer = model.model.layers[0]
# layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
# qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
# assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
# assert isinstance(qkv_proj.scheme, QuarkOCP_MX)
assert isinstance(layer.mlp.experts.quant_method, QuarkW4A4MXFp4MoEMethod)
# assert isinstance(layer.mlp.experts.quant_method,
# QuarkOCP_MX_MoEMethod)
if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
llm.apply_model(check_model)
# if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
# llm.apply_model(check_model)
output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20)
assert output

View File

@ -11,6 +11,7 @@ import importlib.metadata
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import Optional
import huggingface_hub
import lm_eval
@ -148,39 +149,93 @@ def test_quark_fp8_parity(vllm_runner):
@dataclass
class ModelCase:
model_id: str
tp: int
@dataclass
class GSM8KAccuracyTestConfig:
class AccuracyTestConfig:
model_name: str
excepted_value: float
def get_model_args(self) -> str:
return (
f"pretrained={self.model_name},"
"dtype=auto,add_bos_token=True,tensor_parallel_size=8,gpu_memory_utilization=0.7,max_model_len=38768"
)
def get_model_args(
self,
tp_size: int,
model_max_len: Optional[int] = None,
kwargs: Optional[dict] = None,
) -> dict:
if kwargs is None:
kwargs = {}
model_args = {
"pretrained": self.model_name,
"dtype": "auto",
"add_bos_token": True,
"tensor_parallel_size": tp_size,
"gpu_memory_utilization": 0.7,
**kwargs,
}
if model_max_len is not None:
model_args["max_model_len"] = model_max_len
return model_args
ACCURACY_CONFIGS = [
GSM8K_ACCURACY_CONFIGS = [
# Private model.
GSM8KAccuracyTestConfig(
AccuracyTestConfig(
model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
excepted_value=0.96,
),
]
WIKITEXT_ACCURACY_CONFIGS = [
AccuracyTestConfig(
model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3",
excepted_value=11.3,
),
AccuracyTestConfig(
model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2",
excepted_value=10.6,
),
AccuracyTestConfig(
model_name="fxmarty/qwen_1.5-moe-a2.7b-mxfp4", excepted_value=12.4
),
]
@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
if torch.cuda.device_count() < tp_size:
pytest.skip(
f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}"
)
task = "wikitext"
rtol = 0.1
# Smaller cuda_graph_sizes to speed up the test.
results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(
tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
),
tasks=task,
batch_size=64,
)
EXPECTED_VALUE = config.excepted_value
measured_value = results["results"][task]["word_perplexity,none"]
assert (
measured_value < EXPECTED_VALUE + rtol
and measured_value > EXPECTED_VALUE - rtol
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
not HF_HUB_AMD_ORG_ACCESS,
reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
if torch.cuda.device_count() < 8:
pytest.skip(
f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
@ -193,7 +248,7 @@ def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(),
model_args=config.get_model_args(tp_size=8, model_max_len=38768),
tasks=task,
batch_size=64,
num_fewshot=8,

View File

@ -9,6 +9,10 @@ import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_DTYPES,
OCP_MX_Scheme,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.utils import cdiv, has_triton_kernels
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@ -30,7 +34,7 @@ def _get_config_dtype_str(
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
ocp_mx_scheme: Optional[str] = None,
) -> Optional[str]:
"""
Return a string used to construct the filename that contains the
@ -43,8 +47,11 @@ def _get_config_dtype_str(
return "int8_w8a16"
elif use_int4_w4a16:
return "int4_w4a16"
elif use_mxfp4_w4a4:
return "mxfp4_w4a4"
elif ocp_mx_scheme is not None:
# The output of this function is passed to `try_get_optimal_moe_config`,
# and since we only simulate OCP MX execution in fused_moe for now,
# we do NOT look for `*,dtype=w_mxfp4_a_mxfp4.json` tuning configs.
return None
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
@ -289,8 +296,23 @@ class FusedMoEQuantConfig:
return self._a1.dtype is None and self._w1.dtype == "int4"
@property
def use_mxfp4_w4a4(self) -> bool:
return self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4"
def ocp_mx_scheme(self) -> Union[str, None]:
if not hasattr(self, "_ocp_mx_scheme"):
if (self._a1.dtype is not None and not isinstance(self._a1.dtype, str)) or (
self._w1.dtype is not None and not isinstance(self._w1.dtype, str)
):
self._ocp_mx_scheme = None
else:
ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
self._a1.dtype, self._w1.dtype
)
if ocp_mx_scheme is not None:
ocp_mx_scheme = ocp_mx_scheme.value
self._ocp_mx_scheme = ocp_mx_scheme
return self._ocp_mx_scheme
@property
def use_mxfp4_w4a16(self) -> bool:
@ -310,7 +332,7 @@ class FusedMoEQuantConfig:
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
ocp_mx_scheme=self.ocp_mx_scheme,
dtype=dtype,
)
@ -371,12 +393,14 @@ class FusedMoEQuantConfig:
w2_bias: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
weight_dtype: Union[torch.dtype, str, None] = None,
) -> "FusedMoEQuantConfig":
"""
General builder function for a FusedMoEQuantConfig.
- quant_dtype: Optional quantization type. None if activations are
unquantized or quantized prior to calling. Note: "nvfp4" and
"mxfp4" are the only valid string values for quant_dtype.
unquantized or quantized prior to calling. Note: "nvfp4", "mxfp4",
"mxfp6_e3m2", "mxfp6_e2m3" are the only valid string values
for quant_dtype.
- per_act_token_quant: Activations have per token quantization.
- per_out_ch_quant: Outputs have per channel quantization. (only
for cutlass).
@ -395,11 +419,22 @@ class FusedMoEQuantConfig:
- w1_zp: Optional w1 zero points for int4/int8 quantization.
- w2_zp: Optional w2 zero points for int4/int8 quantization.
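- weight_dtype: Optional weight quantization type, defaulting to quant_dtype.
  Set it to a different OCP MX dtype than quant_dtype for mixed-precision
  schemes (e.g. "mxfp4" weights with "mxfp6_e2m3" activations).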
"""
assert (
not isinstance(quant_dtype, str)
or quant_dtype == "nvfp4"
or quant_dtype == "mxfp4"
)
assert not isinstance(quant_dtype, str) or quant_dtype in {
"nvfp4",
"mxfp4",
"mxfp6_e3m2",
"mxfp6_e2m3",
}
assert not isinstance(weight_dtype, str) or weight_dtype in {
"nvfp4",
"mxfp4",
"mxfp6_e3m2",
"mxfp6_e2m3",
}
if weight_dtype is None:
weight_dtype = quant_dtype
a_shape, w_shape = _quant_flags_to_group_shape(
quant_dtype, per_act_token_quant, per_out_ch_quant, block_shape
)
@ -407,10 +442,10 @@ class FusedMoEQuantConfig:
_a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale),
_a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale),
_w1=FusedMoEQuantDesc(
quant_dtype, w_shape, w1_scale, g1_alphas, w1_zp, w1_bias
weight_dtype, w_shape, w1_scale, g1_alphas, w1_zp, w1_bias
),
_w2=FusedMoEQuantDesc(
quant_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
),
)
assert quant_config.per_act_token_quant == per_act_token_quant
@ -482,9 +517,11 @@ def mxfp4_w4a16_moe_quant_config(
)
def mxfp4_w4a4_moe_quant_config(
def ocp_mx_moe_quant_config(
quant_dtype: str,
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
weight_dtype: Optional[str] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
@ -494,8 +531,10 @@ def mxfp4_w4a4_moe_quant_config(
"""
Construct a quant config for OCP MX (MXFP4/MXFP6) activations and weights.
"""
assert quant_dtype in OCP_MX_DTYPES
return FusedMoEQuantConfig.make(
"mxfp4",
quant_dtype=quant_dtype,
weight_dtype=weight_dtype,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,

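As a usage sketch for the new `ocp_mx_moe_quant_config` helper (untested; the scale shapes are illustrative placeholders rather than values from this diff), a mixed FP6-activation / FP4-weight MoE config could be built as follows:

```python
import torch

from vllm.model_executor.layers.fused_moe.config import ocp_mx_moe_quant_config

# Placeholder e8m0 group scales; real call sites pass the layer's registered
# parameters (e.g. layer.w13_weight_scale and layer.w2_weight_scale).
w13_scale = torch.empty(8, 2 * 1408, 2048 // 32, dtype=torch.uint8)
w2_scale = torch.empty(8, 2048, 1408 // 32, dtype=torch.uint8)

quant_config = ocp_mx_moe_quant_config(
    quant_dtype="mxfp6_e2m3",  # activation dtype
    weight_dtype="mxfp4",      # weight dtype; defaults to quant_dtype if omitted
    w1_scale=w13_scale,
    w2_scale=w2_scale,
)
assert quant_config.ocp_mx_scheme == "w_mxfp4_a_mxfp6_e2m3"
```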
View File

@ -640,7 +640,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
assert self.quant_config.ocp_mx_scheme is None, "NYI"
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@ -835,7 +835,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
assert self.quant_config.ocp_mx_scheme is None, "NYI"
assert max_num_tokens > 0
assert num_dispatchers > 0
self.max_num_tokens = max_num_tokens

View File

@ -42,6 +42,8 @@ from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
@ -1323,7 +1325,7 @@ def inplace_fused_experts(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
ocp_mx_scheme: Optional[str] = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1350,7 +1352,7 @@ def inplace_fused_experts(
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
use_mxfp4_w4a4,
ocp_mx_scheme,
per_channel_quant,
global_num_experts,
expert_map,
@ -1378,7 +1380,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
ocp_mx_scheme: Optional[str] = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1420,7 +1422,7 @@ def outplace_fused_experts(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
ocp_mx_scheme: Optional[str] = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1447,7 +1449,7 @@ def outplace_fused_experts(
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
use_mxfp4_w4a4,
ocp_mx_scheme,
per_channel_quant,
global_num_experts,
expert_map,
@ -1474,7 +1476,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
ocp_mx_scheme: Optional[str] = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1599,7 +1601,7 @@ def fused_experts(
use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16,
use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4,
ocp_mx_scheme=quant_config.ocp_mx_scheme,
per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
@ -1622,7 +1624,7 @@ GELU_NO_MUL: str = activation_without_mul("gelu")
def _get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_mxfp4_w4a4: bool,
ocp_mx_scheme: Optional[str],
) -> Union[None, torch.dtype, str]:
"""
Get the quantization type based on the quantization strategy flags.
@ -1635,8 +1637,12 @@ def _get_config_quant_dtype(
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
elif use_mxfp4_w4a4:
elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
return "mxfp4"
elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}:
return "mxfp6_e3m2"
elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
return "mxfp6_e2m3"
return None
@ -1653,7 +1659,7 @@ def fused_experts_impl(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
ocp_mx_scheme: Optional[str] = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@ -1670,9 +1676,23 @@ def fused_experts_impl(
# Check constraints.
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
elif use_mxfp4_w4a4:
# 16bit activation and fp4x2 packed weight
assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch"
elif ocp_mx_scheme is not None:
if ocp_mx_scheme in {
"w_mxfp4_a_mxfp4",
"w_mxfp4_a_mxfp6_e3m2",
"w_mxfp4_a_mxfp6_e2m3",
}:
# 16bit activation and fp4x2 packed weight
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
elif ocp_mx_scheme in {
"w_mxfp6_e3m2_a_mxfp6_e3m2",
"w_mxfp6_e2m3_a_mxfp6_e2m3",
}:
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
"hidden size mismatch"
)
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
else:
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
@ -1699,7 +1719,7 @@ def fused_experts_impl(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
ocp_mx_scheme=ocp_mx_scheme,
dtype=hidden_states.dtype,
)
@ -1708,7 +1728,7 @@ def fused_experts_impl(
quant_dtype = _get_config_quant_dtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_mxfp4_w4a4=use_mxfp4_w4a4,
ocp_mx_scheme=ocp_mx_scheme,
)
get_config_func = functools.partial(
@ -1748,12 +1768,40 @@ def fused_experts_impl(
out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
if use_mxfp4_w4a4:
# Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None
if ocp_mx_scheme is not None:
# TODO: On platforms for which `current_platform.supports_mx()` is True
# and for which we have a native OCP MX fused MoE kernel,
# this dequantization step should not be done.
if ocp_mx_scheme in {
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
}:
# Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w2_scale = None
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (

View File

@ -16,11 +16,15 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
quant_dequant_mxfp4,
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import mxfp8_quantize
from vllm.platforms import current_platform
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
quant_dequant_mxfp6,
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize,
)
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.utils.flashinfer import fp4_quantize
from vllm.utils.flashinfer import flashinfer_fp4_quantize
@triton.jit
@ -106,12 +110,14 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
return x.flatten()[: prod(v)].view(*v)
def _fp4_quantize(
def _nvfp4_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
is_sf_swizzled_layout: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout)
return flashinfer_fp4_quantize(
A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
)
def _fp8_quantize(
@ -174,15 +180,16 @@ def _mxfp4_quantize(
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, None]:
assert block_shape is None
if not current_platform.supports_mx():
A = quant_dequant_mxfp4(A)
else:
raise NotImplementedError()
# TODO: native mxfp4 is currently not integrated in vLLM,
# so we simulate even on devices that support this data type natively.
# Once integrated, `current_platform.supports_mx()` should be used to
# control quantize+dequantize, or simply quantize here down to mxfp4.
A = quant_dequant_mxfp4(A)
return A, None
def _mxfp8_quantize(
def _mxfp8_e4m3_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token_quant: bool,
@ -191,7 +198,41 @@ def _mxfp8_quantize(
assert A_scale is None
assert not per_act_token_quant
assert block_shape is None
return mxfp8_quantize(A)
return mxfp8_e4m3_quantize(A)
def _mxfp6_e3m2_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, None]:
assert block_shape is None
# TODO: native mxfp6 is currently not integrated in vLLM,
# so we simulate even on devices that support this data type natively.
# Eventually, there should be a check based on
# `current_platform.supports_mx()` here.
A = quant_dequant_mxfp6(A, quant_dtype="fp6_e3m2")
return A, None
def _mxfp6_e2m3_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, None]:
assert block_shape is None
# TODO: native mxfp6 is currently not integrated in vLLM,
# so we simulate even on devices that support this data type natively.
# Eventually, there should be a check based on
# `current_platform.supports_mx()` here.
A = quant_dequant_mxfp6(A, quant_dtype="fp6_e2m3")
return A, None
def moe_kernel_quantize_input(
@ -207,11 +248,17 @@ def moe_kernel_quantize_input(
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "nvfp4":
return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled)
return _nvfp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled)
elif quant_dtype == "mxfp4":
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp8":
return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape)
# TODO: `quant_dtype == "mxfp8"` is ambiguous,
# should be fp8_e4m3. OCP MX also defines `fp8_e5m2`.
return _mxfp8_e4m3_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp6_e3m2":
return _mxfp6_e3m2_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp6_e2m3":
return _mxfp6_e2m3_quantize(A, A_scale, per_act_token_quant, block_shape)
else:
return A, A_scale

View File

@ -17,8 +17,8 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
mxfp4_w4a4_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
@ -776,7 +776,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
else:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
return mxfp4_w4a4_moe_quant_config(
return ocp_mx_moe_quant_config(
quant_dtype="mxfp4",
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=w1_scale,

View File

@ -23,8 +23,8 @@ from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E
QuarkMoEMethod,
)
from vllm.model_executor.layers.quantization.quark.schemes import (
QuarkOCP_MX,
QuarkScheme,
QuarkW4A4MXFP4,
QuarkW8A8Fp8,
QuarkW8A8Int8,
)
@ -235,7 +235,7 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
def _is_mx_fp4(
def _is_ocp_mx(
self,
weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]],
@ -243,32 +243,22 @@ class QuarkConfig(QuantizationConfig):
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
logger.debug(
"Quark model is not in MX-FP4 format: "
"Quark model is not in OCP MX format: "
"weight_quant or input_quant not set"
)
return False
# Input and weight dtype needs to be fp4.
if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4":
logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
return False
# Input and weight qscheme needs to be per group.
if (
weight_quant.get("qscheme") != "per_group"
or input_quant.get("qscheme") != "per_group"
):
logger.debug("Quark model is not in MX-FP4 format: not per_group")
logger.debug("Quark model is not in OCP MX format: not per_group")
return False
# Input and weight group size needs to be 32.
if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
return False
# Activations need to use dynamic quantization.
if input_quant.get("is_dynamic") is False:
logger.debug("Quark model is not in MX-FP4 format: not activation dynamic")
logger.debug("Quark model is not in OCP MX format: not group_size=32")
return False
# Activations and weight scales need to be in e8m0 format.
@ -276,7 +266,19 @@ class QuarkConfig(QuantizationConfig):
weight_quant.get("scale_format") != "e8m0"
or input_quant.get("scale_format") != "e8m0"
):
logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
logger.debug("Quark model is not in OCP MX format: not scale_format e8m0")
return False
# Input and weight dtypes need to be any of fp4,
# fp6_e3m2 or fp6_e2m3, possibly mixed.
if weight_quant.get("dtype") not in {
"fp4",
"fp6_e3m2",
"fp6_e2m3",
} or input_quant.get("dtype") not in {"fp4", "fp6_e3m2", "fp6_e2m3"}:
logger.debug(
"Quark model is not in OCP MX format: dtype not fp4, fp6_e3m2, fp6_e2m3"
)
return False
return True
@ -348,8 +350,8 @@ class QuarkConfig(QuantizationConfig):
is_static_input_scheme=True,
input_symmetric=input_config.get("symmetric"),
)
elif self._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFP4(weight_config, input_config)
elif self._is_ocp_mx(weight_config, input_config):
return QuarkOCP_MX(weight_config, input_config)
raise NotImplementedError(
"No quark compatible scheme was found. "

View File

@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
mxfp4_w4a4_moe_quant_config,
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
@ -25,7 +25,10 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
@ -38,7 +41,7 @@ from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod"]
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
class QuarkMoEMethod(FusedMoEMethodBase):
@ -64,10 +67,8 @@ class QuarkMoEMethod(FusedMoEMethodBase):
if quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFp4MoEMethod(
weight_config, input_config, module.moe_config
)
elif quant_config._is_ocp_mx(weight_config, input_config):
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
@ -434,7 +435,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def __init__(
self,
weight_config: dict[str, Any],
@ -456,16 +457,23 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
self.static_input_scales = not self.input_quant.get("is_dynamic")
self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
self.input_dtype, self.weight_dtype
)
if self.static_input_scales:
raise NotImplementedError(
"QuarkW4A4MXFp4MoEMethod with static input scales is currently "
"QuarkOCP_MX_MoEMethod with static input scales is currently "
"not implemented. Please open an issue."
)
if not current_platform.supports_mx():
self.emulate = True
logger.warning_once(
"The current platform does not support native MXFP4 "
"The current platform does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision."
@ -473,13 +481,22 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
else:
self.emulate = True
logger.warning_once(
"The current platform supports native MXFP4 "
"The current platform supports native MXFP4/MXFP6 "
"computation, but kernels are not yet integrated in vLLM. "
"Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision."
)
def get_packed_dim(self, dim: int, quant_dtype: str):
if quant_dtype == "mxfp4":
assert dim % 2 == 0
return dim // 2
else:
# FP6 packs 4 values * 6 bits = 24 bits into 3 bytes.
assert (dim * 3) % 4 == 0
return (dim * 3) // 4
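As a worked example with illustrative sizes: for hidden_size = 4096, get_packed_dim returns 4096 // 2 = 2048 uint8 columns for mxfp4 (two FP4 values per byte) and 4096 * 3 // 4 = 3072 uint8 columns for mxfp6 (four FP6 values in three bytes); the hidden-size asserts added to fused_experts_impl earlier in this diff invert these relations.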
def create_weights(
self,
layer: torch.nn.Module,
@ -502,7 +519,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // 2,
self.get_packed_dim(hidden_size, self.weight_dtype),
dtype=params_dtype,
),
requires_grad=False,
@ -515,7 +532,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // 2,
self.get_packed_dim(intermediate_size_per_partition, self.weight_dtype),
dtype=params_dtype,
),
requires_grad=False,
@ -552,7 +569,9 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
return mxfp4_w4a4_moe_quant_config(
return ocp_mx_moe_quant_config(
quant_dtype=self.input_dtype,
weight_dtype=self.weight_dtype,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
@ -587,7 +606,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet."
"EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
)
from vllm.model_executor.layers.fused_moe import fused_experts

View File

@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .quark_ocp_mx import QuarkOCP_MX
from .quark_scheme import QuarkScheme
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
from .quark_w8a8_fp8 import QuarkW8A8Fp8
from .quark_w8a8_int8 import QuarkW8A8Int8
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX"]

View File

@ -1,22 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cache
from typing import Any, Callable, Optional
from fractions import Fraction
from functools import cache, partial
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
from vllm import envs
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE,
dequant_mxfp4,
quant_dequant_mxfp4,
)
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
dequant_mxfp6,
quant_dequant_mxfp6,
)
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
)
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from vllm.platforms import current_platform
from .quark_scheme import QuarkScheme
logger = init_logger(__name__)
@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
@ -96,14 +108,11 @@ try:
fake_impl=gemm_with_dynamic_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
except (ImportError, AttributeError):
dynamic_mxfp4_quant = gemm_afp4wfp4 = None
__all__ = ["QuarkW4A4MXFP4"]
class QuarkW4A4MXFP4(QuarkScheme):
class QuarkOCP_MX(QuarkScheme):
def __init__(
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
):
@ -111,8 +120,45 @@ class QuarkW4A4MXFP4(QuarkScheme):
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.emulate = not current_platform.supports_mx()
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
self.input_dtype, self.weight_dtype
)
if self.weight_dtype == "mxfp4":
self.packed_factor: Union[int, Fraction] = 2
self.dequant_func = dequant_mxfp4
else:
self.packed_factor = Fraction(numerator=8, denominator=6)
self.dequant_func = partial(
dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "")
)
if self.input_dtype == "mxfp4":
self.quant_dequant_func = quant_dequant_mxfp4
else:
self.quant_dequant_func = partial(
quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "")
)
self.static_input_scales = not input_quant_spec.get("is_dynamic")
if self.static_input_scales:
raise NotImplementedError(
"QuarkOCP_MX with static input scales is currently not "
"implemented. Please open an issue."
)
# TODO: integrate (or test) mixed-precision kernel.
self.emulate = not current_platform.supports_mx() or (
self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
)
self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
# Currently need these kernels if not emulating
raise NotImplementedError(
@ -121,6 +167,41 @@ class QuarkW4A4MXFP4(QuarkScheme):
"https://github.com/ROCm/aiter for installation details."
)
if not current_platform.supports_mx():
logger.warning_once(
"The current platform does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision."
)
if current_platform.supports_mx() and (
self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
):
logger.warning_once(
"The current platform supports native MXFP4/MXFP6 "
f"computation, but kernels for input_dtype={self.input_dtype} "
f"and weight_dtype={self.weight_dtype} are not yet integrated "
"in vLLM. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision."
)
def get_packed_dim(self, dim: int, quant_dtype: str):
if quant_dtype == "mxfp4":
assert dim % 2 == 0
return dim // 2
elif quant_dtype in {"mxfp6_e3m2", "mxfp6_e2m3"}:
# FP6 packs 4 values * 6 bits = 24 bits into 3 bytes.
assert (dim * 3) % 4 == 0
return (dim * 3) // 4
else:
raise NotImplementedError(
"Unsupported quant_dtype in QuarkOCP_MX.get_packed_dim, "
f"got quant_dtype={quant_dtype}. Something is wrong, please "
"open an issue."
)
@classmethod
def get_min_capability(cls) -> int:
return 70
@ -132,37 +213,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
try:
from quark.torch.export.nn.modules import realquantizer
from quark.torch.quantization.config.config import QuantizationSpec
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use AMD Quark "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`."
) from err
weight_quant_spec = QuantizationSpec.from_dict(self.weight_quant_spec)
weight_quantizer = realquantizer.get_real_quantizer(
qspec=weight_quant_spec,
quantizer=None,
real_quantized=True,
reorder=False,
float_dtype=self.out_dtype,
scale_shape=layer.weight_scale.shape,
zero_point_shape=None,
)
weight_quantizer.scale.data = layer.weight_scale.data
layer.weight = torch.nn.Parameter(
weight_quantizer(layer.weight.data).to(self.out_dtype),
requires_grad=False,
)
layer.weight_scale = None
# This call is necessary to release the scales memory.
torch.cuda.empty_cache()
else:
if self.rocm_use_aiter_fp4_asm_gemm:
# shuffle weight scale
@ -204,13 +254,13 @@ class QuarkW4A4MXFP4(QuarkScheme):
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // 2,
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=2,
packed_factor=self.packed_factor,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
@ -235,9 +285,9 @@ class QuarkW4A4MXFP4(QuarkScheme):
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.emulate:
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
x = quant_dequant_mxfp4(x)
return F.linear(x, dq_w, bias)
dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype)
qdq_x = self.quant_dequant_func(x)
return F.linear(qdq_x, dq_w, bias)
else:
return torch.ops.vllm.gemm_with_dynamic_quant(
x,

View File

@ -10,8 +10,6 @@ from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
OCP_MX_BLOCK_SIZE = 32
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
@ -144,6 +142,14 @@ def _quant_dequant_mxfp4_fake(
return torch.empty_like(x)
# Wrap these operations in a torch custom op to avoid errors such as
# torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
# Explanation: Dynamo does not know how to trace the builtin
# `kernel_ext.PyCapsule.dq_uint8_mxfp4_to_half.` This function is either a
# Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python
# extension (perhaps created with pybind).
# TODO: Make sure there is no way to avoid having these functions
# marked as skipped by dynamo.
try:
direct_register_custom_op(
op_name="dequant_mxfp4",

View File

@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
from vllm.utils import direct_register_custom_op
def _quant_dequant_mxfp6(
x: torch.Tensor,
quant_dtype: str,
scale_calculation_mode: str = "even",
) -> torch.Tensor:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
fake_quantize_fp4_fp6_per_group_with_scale,
)
from quark.torch.quantization.utils import even_round, reshape_to_blocks
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use "
"MX-FP6 models. Please install it with `pip install "
"amd-quark`."
) from err
axis = -1
block_x = reshape_to_blocks(x, OCP_MX_BLOCK_SIZE, axis)
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
amax = amax.squeeze(-1)
# TODO: there are other rounding strategies supported in quark and in the
# config.json that we do not check for here!
if scale_calculation_mode != "even":
raise NotImplementedError(
f"Scale calculation mode {scale_calculation_mode} is not yet "
"supported in MX-FP6 quantization"
)
scale = even_round(amax, quant_dtype)
# Apply dequantize(quantize(x)).
x = fake_quantize_fp4_fp6_per_group_with_scale(
x,
scale.to(x.device),
axis=axis,
group_size=OCP_MX_BLOCK_SIZE,
quant_dtype=quant_dtype,
)
return x
def _quant_dequant_mxfp6_fake(
x: torch.Tensor,
quant_dtype: str,
scale_calculation_mode: str = "even",
) -> torch.Tensor:
return torch.empty_like(x)
def _dequant_mxfp6(
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str
) -> torch.Tensor:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
dequantize_fp4_fp6_per_group,
)
from quark.torch.utils.pack import create_pack_method
except ImportError as e:
raise ImportError(
"The package `amd-quark` is required to use "
"MX-FP6 models. Please install it with `pip install "
"amd-quark`."
) from e
pack_method = create_pack_method(None, dtype=quant_dtype)
unpacked_x = pack_method.unpack(x, reorder=False)
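# e8m0 scales store a biased exponent in a uint8; 2 ** (e - 127) recovers
# the per-group scale factor.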
scale = 2 ** (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype)
# TODO: `dequantize_fp4_fp6_per_group` and `prepare_inputs_per_group`
# always return fp32.
return dequantize_fp4_fp6_per_group(
unpacked_x,
scale,
axis=-1,
group_size=OCP_MX_BLOCK_SIZE,
quant_dtype=quant_dtype,
).to(float_dtype)
def _dequant_mxfp6_fake(
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str
) -> torch.Tensor:
assert (x.shape[-1] * 4) % 3 == 0
return torch.empty(
(*x.shape[:-1], (x.shape[-1] * 4) // 3), dtype=float_dtype, device=x.device
)
# Wrap these operations in a torch custom op to avoid errors such as
# torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
# Explanation: Dynamo does not know how to trace the builtin
# `kernel_ext.PyCapsule.dq_uint8_mxfp4_to_half.` This function is either a
# Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python
# extension (perhaps created with pybind).
# TODO: Make sure there is no way to avoid having these functions
# marked as skipped by dynamo.
try:
direct_register_custom_op(
op_name="quant_dequant_mxfp6",
op_func=_quant_dequant_mxfp6,
mutates_args=[],
fake_impl=_quant_dequant_mxfp6_fake,
)
except AttributeError as error:
raise error
# Expose keyword arguments.
def quant_dequant_mxfp6(
x: torch.Tensor,
quant_dtype: str,
scale_calculation_mode: str = "even",
) -> torch.Tensor:
return torch.ops.vllm.quant_dequant_mxfp6(x, quant_dtype, scale_calculation_mode)
try:
direct_register_custom_op(
op_name="dequant_mxfp6",
op_func=_dequant_mxfp6,
mutates_args=[],
fake_impl=_dequant_mxfp6_fake,
)
except AttributeError as error:
raise error
def dequant_mxfp6(
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str
) -> torch.Tensor:
return torch.ops.vllm.dequant_mxfp6(x, scale, float_dtype, quant_dtype)
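A minimal usage sketch of these wrappers (assuming `amd-quark` is installed and a GPU is available; shapes and dtypes are illustrative):

```python
import torch

from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
    dequant_mxfp6,
    quant_dequant_mxfp6,
)

x = torch.randn(4, 64, dtype=torch.bfloat16, device="cuda")

# Activation QDQ: quantize to MXFP6 (E2M3, 32-element groups) and dequantize
# back, so the tensor carries MXFP6 rounding error while staying unpacked.
x_qdq = quant_dequant_mxfp6(x, quant_dtype="fp6_e2m3")

# Weight dequantization: `packed` would be a uint8 tensor of packed FP6 values
# and `scale` the per-group e8m0 scales loaded from a Quark checkpoint; the
# unpacked last dimension is packed.shape[-1] * 4 // 3.
# w = dequant_mxfp6(packed, scale, float_dtype=torch.bfloat16, quant_dtype="fp6_e2m3")
```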

View File

@ -8,9 +8,9 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
try:
from flashinfer import mxfp8_quantize
from flashinfer import mxfp8_quantize as mxfp8_e4m3_quantize
except ImportError as err:
raise ImportError(
"The package `flashinfer` is required to do "
@ -18,4 +18,4 @@ def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"`pip install flashinfer`"
) from err
return mxfp8_quantize(x, is_sf_swizzled_layout=False)
return mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False)

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Union
from vllm.logger import init_logger
logger = init_logger(__name__)
OCP_MX_BLOCK_SIZE = 32
OCP_MX_DTYPES = {
"mxfp4",
"mxfp6_e3m2",
"mxfp6_e2m3",
"mxfp8_e4m3",
"mxfp8_e5m2",
"mxint8",
}
SUPPORTED_OCP_MX_DTYPES = {"mxfp4", "mxfp6_e3m2", "mxfp6_e2m3"}
class OCP_MX_Scheme(str, Enum):
w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4"
w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2"
w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3"
w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2"
w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3"
@classmethod
def from_quant_dtype(
cls, input_dtype: Union[str, None], weight_dtype: Union[str, None]
):
if input_dtype not in OCP_MX_DTYPES or weight_dtype not in OCP_MX_DTYPES:
return None
elif input_dtype == "mxfp4" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp4
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp6_e3m2
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp6_e2m3
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2_a_mxfp6_e3m2
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3_a_mxfp6_e2m3
else:
logger.warning(
"input_dtype='%s' and"
" weight_dtype='%s' is not supported "
"in OCP_MX_Scheme at the moment.",
input_dtype,
weight_dtype,
)
return None

View File

@ -337,8 +337,11 @@ class scalar_types:
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
# and https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
float6_e2m3f = ScalarType.float_(2, 3, True, NanRepr.NONE)
# fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
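For reference, these 6-bit formats comprise 1 sign bit plus 3 exponent / 2 mantissa bits (E3M2) or 2 exponent / 3 mantissa bits (E2M3); per the OCP MX specification they encode no NaN or Inf values, hence NanRepr.NONE.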

View File

@ -89,7 +89,7 @@ flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
"flashinfer.fused_moe", "cutlass_fused_moe"
)
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
"flashinfer", "nvfp4_block_scale_interleave"
)
@ -442,7 +442,7 @@ __all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
"flashinfer_cutlass_fused_moe",
"fp4_quantize",
"flashinfer_fp4_quantize",
"nvfp4_block_scale_interleave",
"trtllm_fp4_block_scale_moe",
"autotune",