[Quantization] Quark MXFP4 format loading (#16943)
commit db593aa67f (parent f98e307588)
tests/models/quantization/test_mxfp4.py (new file, 40 lines added)
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
"""Tests Quark mxfp4 models against ground truth generation
"""
import pytest

from vllm import LLM, SamplingParams

MODELS = ["amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8"]

EXPECTED_STRS_MAP = {
    "amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8": [
        '\n### Key Features\n\n* **High-throughput Inference**: vLL',
        '\nArtificial intelligence (AI) has evolved significantly since its inception in the 1',
        'Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been',
        'A neural network is a machine learning model inspired by the structure of the human brain. It consists of',
        '\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol',
        '\nThe COVID-19 pandemic has had a profound impact on global economic structures and business',
        'The Mona Lisa painting, created by Leonardo da Vinci in the early 16th',
        " everybody knows this proverbial saying, but did you know that it's not entirely accurate?",
    ]
}


@pytest.mark.skip(reason="Model to be released in the future")
@pytest.mark.quant_model
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
    sampling_params = SamplingParams(max_tokens=20, temperature=0)
    llm = LLM(
        model=model_name,
        kv_cache_dtype="fp8",
        quantization="quark",
    )
    outputs = llm.generate(example_prompts, sampling_params)
    for i, output in enumerate(outputs):
        output_str = output.outputs[0].text
        expected_str = EXPECTED_STRS_MAP[model_name][i]
        assert expected_str == output_str, (
            f"Expected: {expected_str!r}\nvLLM: {output_str!r}")
@@ -84,6 +84,7 @@ if TYPE_CHECKING:
    VLLM_ROCM_FP8_PADDING: bool = True
    VLLM_ROCM_MOE_PADDING: bool = True
    VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
    VLLM_QUARK_EMU_MEM_OPT: bool = False
    VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
    VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -583,6 +584,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
    lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
             ("true", "1")),

    # If set, when running in Quark emulation mode, do not dequantize the
    # weights at load time. Instead, dequantize weights on-the-fly during
    # kernel execution.
    # This allows running larger models at the cost of slower inference.
    # This flag has no effect when not running in Quark emulation mode.
    "VLLM_QUARK_EMU_MEM_OPT":
    lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))),

    # Divisor for dynamic query scale factor calculation for FP8 KV Cache
    "Q_SCALE_CONSTANT":
    lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
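For reference, a minimal sketch of how the new flag might be used for offline inference; the model name and engine arguments are taken from the test above, while setting the variable from Python before the engine is constructed (rather than exporting it in the shell) is an assumption for illustration:

# Illustrative only: opt in to on-the-fly weight dequantization in Quark
# emulation mode, trading inference speed for lower memory use at load time.
import os

os.environ["VLLM_QUARK_EMU_MEM_OPT"] = "1"  # evaluated lazily via vllm.envs

from vllm import LLM, SamplingParams

llm = LLM(
    model="amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8",
    quantization="quark",
    kv_cache_dtype="fp8",
)
print(llm.generate(["What is MX-FP4 quantization?"],
                   SamplingParams(max_tokens=20, temperature=0)))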
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, cast

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
@@ -15,13 +16,15 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
    QuarkMoEMethod)
from vllm.model_executor.layers.quantization.quark.schemes import (
    QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
    QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import (
    deep_compare, should_ignore_layer)
from vllm.platforms import current_platform

__all__ = ["QuarkLinearMethod"]

logger = init_logger(__name__)


class QuarkConfig(QuantizationConfig):

@@ -67,6 +70,7 @@ class QuarkConfig(QuantizationConfig):
            return QuarkLinearMethod(self)
        if isinstance(layer, Attention):
            return QuarkKVCacheMethod(self)

        if isinstance(layer, FusedMoE):
            return QuarkMoEMethod.get_moe_method(self,
                                                 module=layer,
@@ -205,6 +209,54 @@
        # Only symmetric weight quantization supported.
        return is_int8_dtype and is_tensor and is_weight_symmetric and is_static

    def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
                   input_quant: Optional[Dict[str, Any]]) -> bool:
        # Confirm weights and input quantized.
        if weight_quant is None or input_quant is None:
            logger.debug("Quark model is not in MX-FP4 format: "
                         "weight_quant or input_quant not set")
            return False

        # Input and weight dtype needs to be fp4.
        if weight_quant.get("dtype") != "fp4" or input_quant.get(
                "dtype") != "fp4":
            logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
            return False

        # Input and weight qscheme needs to be per group.
        if weight_quant.get("qscheme") != "per_group" or input_quant.get(
                "qscheme") != "per_group":
            logger.debug("Quark model is not in MX-FP4 format: not per_group")
            return False

        # Input and weight group size needs to be 32.
        if weight_quant.get("group_size") != 32 or input_quant.get(
                "group_size") != 32:
            logger.debug(
                "Quark model is not in MX-FP4 format: not group_size=32")
            return False

        # Weights need to use static quantization.
        if weight_quant.get("is_dynamic") is True:
            logger.debug(
                "Quark model is not in MX-FP4 format: not weight static")
            return False

        # Activations need to use dynamic quantization.
        if input_quant.get("is_dynamic") is False:
            logger.debug(
                "Quark model is not in MX-FP4 format: not activation dynamic")
            return False

        # Activations and weight scales need to be in e8m0 format.
        if weight_quant.get("scale_format") != "e8m0" or input_quant.get(
                "scale_format") != "e8m0":
            logger.debug(
                "Quark model is not in MX-FP4 format: not scale_format e8m0")
            return False

        return True

    def _find_matched_config(self, layer_name: str,
                             module: torch.nn.Module) -> Dict[str, Any]:

@@ -269,6 +321,8 @@
            return QuarkW8A8Int8(qscheme=weight_qscheme,
                                 is_static_input_scheme=True,
                                 input_symmetric=input_config.get("symmetric"))
        elif self._is_mx_fp4(weight_config, input_config):
            return QuarkW4A4MXFP4(weight_config, input_config)

        raise NotImplementedError("No quark compatible scheme was found. "
                                  f"Weight config: {weight_config}, "
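The MX-FP4 detection only inspects a handful of keys in the per-layer weight and input quantization configs. A minimal sketch of spec dictionaries that would satisfy `_is_mx_fp4` and therefore resolve to `QuarkW4A4MXFP4`; the keys mirror exactly what the code above checks, while the exact layout of a real Quark `config.json` is an assumption here:

# Hypothetical per-layer specs carrying only the fields _is_mx_fp4 inspects;
# real Quark configs contain additional keys.
weight_config = {
    "dtype": "fp4",
    "qscheme": "per_group",
    "group_size": 32,
    "is_dynamic": False,      # weights: static quantization
    "scale_format": "e8m0",
}
input_config = {
    "dtype": "fp4",
    "qscheme": "per_group",
    "group_size": 32,
    "is_dynamic": True,       # activations: dynamic quantization
    "scale_format": "e8m0",
}

# With these specs, _is_mx_fp4(weight_config, input_config) returns True and
# the scheme selection above builds QuarkW4A4MXFP4(weight_config, input_config);
# any mismatch falls through to the NotImplementedError.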
@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

from .quark_scheme import QuarkScheme
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
from .quark_w8a8_fp8 import QuarkW8A8Fp8
from .quark_w8a8_int8 import QuarkW8A8Int8

__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"]
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F

import vllm.envs as envs
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.platforms import current_platform

__all__ = ["QuarkW4A4MXFP4"]


class QuarkW4A4MXFP4(QuarkScheme):

    def __init__(self, weight_quant_spec: Dict[str, Any],
                 input_quant_spec: Dict[str, Any]):
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec
        self.emulate = not current_platform.supports_mx()

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                requires_grad=False)

        if self.emulate:
            try:
                from quark.torch.export.nn.modules import realquantizer
                from quark.torch.quantization.config.config import (
                    QuantizationSpec)
            except ImportError as err:
                raise ImportError(
                    "The package `amd-quark` is required to use AMD Quark "
                    "MX-FP4 models. Please install it with `pip install "
                    "amd-quark`.") from err

            weight_quant_spec = QuantizationSpec.from_dict(
                self.weight_quant_spec)

            weight_quantizer = realquantizer.get_real_quantizer(
                qspec=weight_quant_spec,
                quantizer=None,
                real_quantized=True,
                reorder=False,
                float_dtype=self.out_dtype,
                scale_shape=layer.weight_scale.shape,
                zero_point_shape=None,
            )
            weight_quantizer.scale.data = layer.weight_scale.data

            if not envs.VLLM_QUARK_EMU_MEM_OPT:
                layer.weight = torch.nn.Parameter(
                    weight_quantizer(layer.weight.data).to(self.out_dtype),
                    requires_grad=False,
                )
            else:
                self.weight_quantizer = weight_quantizer
            layer.weight_scale = None

            # This call is necessary to release the scales memory.
            torch.cuda.empty_cache()

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.emulate:
            if envs.VLLM_QUARK_EMU_MEM_OPT:
                dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
            else:
                dq_w = layer.weight
            qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
            return F.linear(qdq_x, dq_w, bias)
        else:
            raise NotImplementedError()
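As context for the shapes above: `create_weights` stores the MX-FP4 weight as `uint8` with `packed_factor=2`, i.e. two 4-bit (E2M1) codes per byte, hence the `input_size_per_partition // 2` inner dimension, while the E8M0 scales use one byte per `OCP_MX_BLOCK_SIZE`-element group. A rough sketch of such nibble packing follows; the low/high nibble order is an assumption for illustration, not necessarily the layout Quark checkpoints use:

import torch


def pack_fp4_nibbles(codes: torch.Tensor) -> torch.Tensor:
    """Pack pairs of 4-bit codes (0..15) along the last dim into uint8 bytes."""
    lo = codes[..., 0::2] & 0xF   # even elements -> low nibble (assumed order)
    hi = codes[..., 1::2] & 0xF   # odd elements  -> high nibble
    return (hi << 4) | lo


def unpack_fp4_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_fp4_nibbles: (..., K // 2) uint8 -> (..., K) uint8."""
    lo = packed & 0xF
    hi = (packed >> 4) & 0xF
    return torch.stack((lo, hi), dim=-1).flatten(start_dim=-2)


codes = torch.randint(0, 16, (4, 64), dtype=torch.uint8)
assert torch.equal(unpack_fp4_nibbles(pack_fp4_nibbles(codes)), codes)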
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py (new file, 45 lines added)
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Tuple

import torch

OCP_MX_BLOCK_SIZE = 32


def per_token_group_quant_mxfp4(x: torch.Tensor,
                                block_k: int,
                                scale_calculation_mode: str = "even"
                                ) -> Tuple[torch.Tensor, torch.Tensor]:
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            fake_quantize_fp4_fp6_per_group_with_scale)
        from quark.torch.quantization.utils import (even_round,
                                                    reshape_to_blocks)
    except ImportError as err:
        raise ImportError("The package `amd-quark` is required to use "
                          "MX-FP4 models. Please install it with `pip install "
                          "amd-quark`.") from err

    axis = -1
    block_x = reshape_to_blocks(x, block_k, axis)
    amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
    amax = amax.squeeze(-1)

    # TODO: there are other rounding strategies supported in quark and in the
    # config.json that we do not check for here!
    if scale_calculation_mode != "even":
        raise NotImplementedError(
            f"Scale calculation mode {scale_calculation_mode} is not yet "
            "supported in MX-FP4 quantization")
    scale = even_round(amax, "fp4")

    # Apply dequantize(quantize(x)).
    x = fake_quantize_fp4_fp6_per_group_with_scale(
        x,
        scale.to(x.device),
        axis=axis,
        group_size=block_k,
        quant_dtype="fp4",
    )

    return x, scale
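The helper above delegates the E8M0 scale selection (`even_round`) and the quantize-dequantize step to `amd-quark`. As a rough, self-contained illustration of the underlying MX idea only, not Quark's implementation: each 32-element block shares a power-of-two scale derived from the block maximum, and each scaled value is rounded to the nearest representable E2M1 magnitude. The exponent choice and rounding below are simplified assumptions:

import torch

# Representable E2M1 (FP4) magnitudes; MX scales are pure powers of two (E8M0).
E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fake_quant_mxfp4_ref(x: torch.Tensor, block: int = 32) -> torch.Tensor:
    """Toy quantize-dequantize: per-block power-of-two scale + E2M1 rounding."""
    xb = x.float().reshape(-1, block)
    amax = xb.abs().amax(dim=-1, keepdim=True)
    # Shared exponent so the block max lands near the top of the E2M1 range
    # (simplified; Quark's "even" rounding may pick exponents differently).
    exp = torch.floor(torch.log2(amax.clamp(min=1e-30))) - 2.0
    scale = torch.pow(2.0, exp)
    scaled = xb / scale
    # Round each magnitude to the nearest grid point, keeping the sign.
    idx = (scaled.abs().unsqueeze(-1) - E2M1_GRID).abs().argmin(dim=-1)
    qdq = E2M1_GRID[idx] * scaled.sign() * scale
    return qdq.reshape(x.shape)


x = torch.randn(2, 64)
print((x - fake_quant_mxfp4_ref(x)).abs().max())  # block-wise quantization error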
@@ -220,7 +220,7 @@ def get_model_architecture(
    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
    ]

    if (model_config.quantization is not None
@@ -339,6 +339,13 @@ class Platform:
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.
        """
        return False

    @classmethod
    def supports_fp8(cls) -> bool:
        """
@@ -327,6 +327,11 @@ class RocmPlatform(Platform):
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx95"])

    @classmethod
    def supports_fp8(cls) -> bool:
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
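Tying the platform hook back to the scheme: `QuarkW4A4MXFP4.__init__` uses this method to decide between native MX execution and the emulation path exercised in `apply_weights`. A small sketch of the gating:

from vllm.platforms import current_platform

# Mirrors QuarkW4A4MXFP4.__init__: fall back to Quark fake-quant emulation
# unless the platform (e.g. a gfx95* GPU on ROCm) reports native MX support.
emulate = not current_platform.supports_mx()
print("MX-FP4 execution path:", "emulated" if emulate else "native")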