diff --git a/tests/models/quantization/test_mxfp4.py b/tests/models/quantization/test_mxfp4.py
new file mode 100644
index 0000000000000..9a060829525e1
--- /dev/null
+++ b/tests/models/quantization/test_mxfp4.py
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+# flake8: noqa
+"""Tests Quark mxfp4 models against ground truth generation
+"""
+import pytest
+
+from vllm import LLM, SamplingParams
+
+MODELS = ["amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8"]
+
+EXPECTED_STRS_MAP = {
+    "amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8": [
+        '\n### Key Features\n\n* **High-throughput Inference**: vLL',
+        '\nArtificial intelligence (AI) has evolved significantly since its inception in the 1',
+        'Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been',
+        'A neural network is a machine learning model inspired by the structure of the human brain. It consists of',
+        '\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol',
+        '\nThe COVID-19 pandemic has had a profound impact on global economic structures and business',
+        'The Mona Lisa painting, created by Leonardo da Vinci in the early 16th',
+        " everybody knows this proverbial saying, but did you know that it's not entirely accurate?",
+    ]
+}
+
+
+@pytest.mark.skip(reason="Model to be released in the future")
+@pytest.mark.quant_model
+@pytest.mark.parametrize("model_name", MODELS)
+def test_models(example_prompts, model_name) -> None:
+    sampling_params = SamplingParams(max_tokens=20, temperature=0)
+    llm = LLM(
+        model=model_name,
+        kv_cache_dtype="fp8",
+        quantization="quark",
+    )
+    outputs = llm.generate(example_prompts, sampling_params)
+    for i, output in enumerate(outputs):
+        output_str = output.outputs[0].text
+        expected_str = EXPECTED_STRS_MAP[model_name][i]
+        assert expected_str == output_str, (
+            f"Expected: {expected_str!r}\nvLLM: {output_str!r}")
diff --git a/vllm/envs.py b/vllm/envs.py
index ea40bfff11b5b..c8bb39ceb7b22 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -84,6 +84,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
     VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
+    VLLM_QUARK_EMU_MEM_OPT: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -583,6 +584,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
              ("true", "1")),
 
+    # If set, when running in Quark emulation mode, do not dequantize the
+    # weights at load time. Instead, dequantize weights on-the-fly during
+    # kernel execution.
+    # This allows running larger models at the cost of slower inference.
+    # This flag has no effect when not running in Quark emulation mode.
+ "VLLM_QUARK_EMU_MEM_OPT": + lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))), + # Divisor for dynamic query scale factor calculation for FP8 KV Cache "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index da23121900849..66e677f56ffd4 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, cast import torch +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -15,13 +16,15 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 QuarkMoEMethod) from vllm.model_executor.layers.quantization.quark.schemes import ( - QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8) + QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8) from vllm.model_executor.layers.quantization.quark.utils import ( deep_compare, should_ignore_layer) from vllm.platforms import current_platform __all__ = ["QuarkLinearMethod"] +logger = init_logger(__name__) + class QuarkConfig(QuantizationConfig): @@ -67,6 +70,7 @@ class QuarkConfig(QuantizationConfig): return QuarkLinearMethod(self) if isinstance(layer, Attention): return QuarkKVCacheMethod(self) + if isinstance(layer, FusedMoE): return QuarkMoEMethod.get_moe_method(self, module=layer, @@ -205,6 +209,54 @@ class QuarkConfig(QuantizationConfig): # Only symmetric weight quantization supported. return is_int8_dtype and is_tensor and is_weight_symmetric and is_static + def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]], + input_quant: Optional[Dict[str, Any]]) -> bool: + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + logger.debug("Quark model is not in MX-FP4 format: " + "weight_quant or input_quant not set") + return False + + # Input and weight dtype needs to be fp4. + if weight_quant.get("dtype") != "fp4" or input_quant.get( + "dtype") != "fp4": + logger.debug("Quark model is not in MX-FP4 format: dtype not fp4") + return False + + # Input and weight qscheme needs to be per group. + if weight_quant.get("qscheme") != "per_group" or input_quant.get( + "qscheme") != "per_group": + logger.debug("Quark model is not in MX-FP4 format: not per_group") + return False + + # Input and weight group size needs to be 32. + if weight_quant.get("group_size") != 32 or input_quant.get( + "group_size") != 32: + logger.debug( + "Quark model is not in MX-FP4 format: not group_size=32") + return False + + # Weights need to use static quantization. + if weight_quant.get("is_dynamic") is True: + logger.debug( + "Quark model is not in MX-FP4 format: not weight static") + return False + + # Activations need to use dynamic quantization. + if input_quant.get("is_dynamic") is False: + logger.debug( + "Quark model is not in MX-FP4 format: not activation dynamic") + return False + + # Activations and weight scales need to be in e8m0 format. 
+ if weight_quant.get("scale_format") != "e8m0" or input_quant.get( + "scale_format") != "e8m0": + logger.debug( + "Quark model is not in MX-FP4 format: not scale_format e8m0") + return False + + return True + def _find_matched_config(self, layer_name: str, module: torch.nn.Module) -> Dict[str, Any]: @@ -269,6 +321,8 @@ class QuarkConfig(QuantizationConfig): return QuarkW8A8Int8(qscheme=weight_qscheme, is_static_input_scheme=True, input_symmetric=input_config.get("symmetric")) + elif self._is_mx_fp4(weight_config, input_config): + return QuarkW4A4MXFP4(weight_config, input_config) raise NotImplementedError("No quark compatible scheme was found. " f"Weight config: {weight_config}, " diff --git a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py index 9069b5a0d515d..d7dac17574ffe 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 from .quark_scheme import QuarkScheme +from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4 from .quark_w8a8_fp8 import QuarkW8A8Fp8 from .quark_w8a8_int8 import QuarkW8A8Int8 -__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"] +__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"] diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py new file mode 100644 index 0000000000000..9da52a732fc41 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F + +import vllm.envs as envs +from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.platforms import current_platform + +__all__ = ["QuarkW4A4MXFP4"] + + +class QuarkW4A4MXFP4(QuarkScheme): + + def __init__(self, weight_quant_spec: Dict[str, Any], + input_quant_spec: Dict[str, Any]): + self.out_dtype = torch.get_default_dtype() + self.qscheme = "per_group" + self.weight_quant_spec = weight_quant_spec + self.input_quant_spec = input_quant_spec + self.emulate = not current_platform.supports_mx() + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + + if self.emulate: + try: + from quark.torch.export.nn.modules import realquantizer + from quark.torch.quantization.config.config import ( + QuantizationSpec) + except ImportError as err: + raise ImportError( + "The package `amd-quark` is required to use AMD Quark " + "MX-FP4 models. 
Please install it with `pip install " + "amd-quark`.") from err + + weight_quant_spec = QuantizationSpec.from_dict( + self.weight_quant_spec) + + weight_quantizer = realquantizer.get_real_quantizer( + qspec=weight_quant_spec, + quantizer=None, + real_quantized=True, + reorder=False, + float_dtype=self.out_dtype, + scale_shape=layer.weight_scale.shape, + zero_point_shape=None, + ) + weight_quantizer.scale.data = layer.weight_scale.data + + if not envs.VLLM_QUARK_EMU_MEM_OPT: + layer.weight = torch.nn.Parameter( + weight_quantizer(layer.weight.data).to(self.out_dtype), + requires_grad=False, + ) + else: + self.weight_quantizer = weight_quantizer + layer.weight_scale = None + + # This call is necessary to release the scales memory. + torch.cuda.empty_cache() + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = PackedvLLMParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=2, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.emulate: + if envs.VLLM_QUARK_EMU_MEM_OPT: + dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype) + else: + dq_w = layer.weight + qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE) + return F.linear(qdq_x, dq_w, bias) + else: + raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py new file mode 100644 index 0000000000000..6312c3934fd4f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Tuple + +import torch + +OCP_MX_BLOCK_SIZE = 32 + + +def per_token_group_quant_mxfp4(x: torch.Tensor, + block_k: int, + scale_calculation_mode: str = "even" + ) -> Tuple[torch.Tensor, torch.Tensor]: + try: + from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( + fake_quantize_fp4_fp6_per_group_with_scale) + from quark.torch.quantization.utils import (even_round, + reshape_to_blocks) + except ImportError as err: + raise ImportError("The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`.") from err + + axis = -1 + block_x = reshape_to_blocks(x, block_k, axis) + amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) + amax = amax.squeeze(-1) + + # TODO: there are other rounding strategies supported in quark and in the + # config.json that we do not check for here! + if scale_calculation_mode != "even": + raise NotImplementedError( + f"Scale calculation mode {scale_calculation_mode} is not yet " + "supported in MX-FP4 quantization") + scale = even_round(amax, "fp4") + + # Apply dequantize(quantize(x)). 
+    x = fake_quantize_fp4_fp6_per_group_with_scale(
+        x,
+        scale.to(x.device),
+        axis=axis,
+        group_size=block_k,
+        quant_dtype="fp4",
+    )
+
+    return x, scale
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 42528cd7e4334..ddc857aebdc86 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -220,7 +220,7 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     mixtral_supported = [
-        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
+        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
     ]
 
     if (model_config.quantization is not None
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 5df0e9d3d0728..f097ecc0a9e4f 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -339,6 +339,13 @@ class Platform:
         """
         return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa
 
+    @classmethod
+    def supports_mx(cls) -> bool:
+        """
+        Returns whether the current platform supports MX types.
+        """
+        return False
+
     @classmethod
     def supports_fp8(cls) -> bool:
         """
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index ff63f9656c01b..8a49203034237 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -327,6 +327,11 @@ class RocmPlatform(Platform):
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
 
+    @classmethod
+    def supports_mx(cls) -> bool:
+        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
+        return any(gfx in gcn_arch for gfx in ["gfx95"])
+
     @classmethod
     def supports_fp8(cls) -> bool:
         gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
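
Usage note (not part of the patch above): the sketch below shows how the new Quark MX-FP4 emulation path could be exercised once the checkpoint referenced in tests/models/quantization/test_mxfp4.py is released. The model name, the quantization="quark" and kv_cache_dtype="fp8" arguments, and the VLLM_QUARK_EMU_MEM_OPT variable all come from the diff; the prompt text is illustrative.

# Hypothetical usage sketch for the Quark MX-FP4 emulation path.
import os

# Optional: keep weights packed in MX-FP4 and dequantize them on the fly
# during each matmul (smaller memory footprint, slower inference). Only
# relevant when the platform reports supports_mx() == False.
os.environ["VLLM_QUARK_EMU_MEM_OPT"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8",
    quantization="quark",
    kv_cache_dtype="fp8",
)
outputs = llm.generate(["Briefly explain the MX-FP4 data format."],
                       SamplingParams(max_tokens=20, temperature=0))
print(outputs[0].outputs[0].text)

On gfx95-class ROCm devices supports_mx() returns True, so QuarkW4A4MXFP4 takes the non-emulated branch, which this patch still leaves as NotImplementedError.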