From 19cc9468fd0fa1701e7cb74b5928b329a1d16cf1 Mon Sep 17 00:00:00 2001 From: CedricHuang <38417461+CedricHwong@users.noreply.github.com> Date: Mon, 22 Dec 2025 11:34:49 +0800 Subject: [PATCH] [Feature]: Support NVIDIA ModelOpt HF FP8 variants FP8_PER_CHANNEL_PER_TOKEN and FP8_PB_WO in vLLM (#30957) --- docs/features/quantization/modelopt.md | 31 +++ tests/quantization/test_modelopt.py | 141 ++++++++++ vllm/config/model.py | 18 +- vllm/model_executor/layers/linear.py | 2 + .../layers/quantization/modelopt.py | 260 +++++++++++++++++- 5 files changed, 437 insertions(+), 15 deletions(-) diff --git a/docs/features/quantization/modelopt.md b/docs/features/quantization/modelopt.md index b02d5ba9e89a2..5c846767bc5b8 100644 --- a/docs/features/quantization/modelopt.md +++ b/docs/features/quantization/modelopt.md @@ -8,6 +8,16 @@ We recommend installing the library with: pip install nvidia-modelopt ``` +## Supported ModelOpt checkpoint formats + +vLLM detects ModelOpt checkpoints via `hf_quant_config.json` and supports the +following `quantization.quant_algo` values: + +- `FP8`: per-tensor weight scale (+ optional static activation scale). +- `FP8_PER_CHANNEL_PER_TOKEN`: per-channel weight scale and dynamic per-token activation quantization. +- `FP8_PB_WO` (ModelOpt may emit `fp8_pb_wo`): block-scaled FP8 weight-only (typically 128×128 blocks). +- `NVFP4`: ModelOpt NVFP4 checkpoints (use `quantization="modelopt_fp4"`). + ## Quantizing HuggingFace Models with PTQ You can quantize HuggingFace models using the example scripts provided in the Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory. @@ -80,3 +90,24 @@ The quantized checkpoint can then be deployed with vLLM. As an example, the foll if __name__ == "__main__": main() ``` + +## Running the OpenAI-compatible server + +To serve a local ModelOpt checkpoint via the OpenAI-compatible API: + +```bash +vllm serve \ + --quantization modelopt \ + --host 0.0.0.0 --port 8000 +``` + +## Testing (local checkpoints) + +vLLM's ModelOpt unit tests are gated by local checkpoint paths and are skipped +by default in CI. To run the tests locally: + +```bash +export VLLM_TEST_MODELOPT_FP8_PC_PT_MODEL_PATH= +export VLLM_TEST_MODELOPT_FP8_PB_WO_MODEL_PATH= +pytest -q tests/quantization/test_modelopt.py +``` diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py index 0298994c396f6..154b29d7017ac 100644 --- a/tests/quantization/test_modelopt.py +++ b/tests/quantization/test_modelopt.py @@ -6,6 +6,7 @@ Run `pytest tests/quantization/test_modelopt.py`. """ import os +from typing import NoReturn import pytest import torch @@ -19,6 +20,28 @@ def enable_pickle(monkeypatch): monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") +def _skip(msg: str) -> NoReturn: + pytest.skip(msg) + raise RuntimeError(msg) + + +def _snapshot_download_or_skip(model_id: str) -> str: + try: + from huggingface_hub import snapshot_download + except Exception as e: # pragma: no cover + _skip(f"huggingface_hub is required to download {model_id}: {e}") + + try: + return snapshot_download( + repo_id=model_id, + repo_type="model", + # These checkpoints are already small; download full repo for simplicity. + allow_patterns=["*"], + ) + except Exception as e: + _skip(f"Failed to download {model_id} from the HF Hub: {e}") + + @pytest.mark.skipif( not is_quant_method_supported("modelopt"), reason="ModelOpt FP8 is not supported on this GPU type.", @@ -91,3 +114,121 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner): output = llm.generate_greedy(["Hello my name is"], max_tokens=4) assert output print(f"ModelOpt FP8 output: {output}") + + +@pytest.mark.skipif( + not is_quant_method_supported("modelopt"), + reason="ModelOpt FP8 is not supported on this GPU type.", +) +def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner): + """Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup.""" + model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt" + model_path = _snapshot_download_or_skip(model_id) + + with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptFp8PcPtLinearMethod, + ) + + assert isinstance(qkv_proj.quant_method, ModelOptFp8PcPtLinearMethod) + assert isinstance(o_proj.quant_method, ModelOptFp8PcPtLinearMethod) + assert isinstance(gate_up_proj.quant_method, ModelOptFp8PcPtLinearMethod) + assert isinstance(down_proj.quant_method, ModelOptFp8PcPtLinearMethod) + + assert qkv_proj.weight.dtype == torch.float8_e4m3fn + assert o_proj.weight.dtype == torch.float8_e4m3fn + assert gate_up_proj.weight.dtype == torch.float8_e4m3fn + assert down_proj.weight.dtype == torch.float8_e4m3fn + + # Per-channel scales; activations are dynamically scaled per token. + assert hasattr(qkv_proj, "weight_scale") + assert qkv_proj.weight_scale.dtype == torch.float32 + assert qkv_proj.weight_scale.dim() == 1 + assert not hasattr(qkv_proj, "input_scale") + + assert hasattr(o_proj, "weight_scale") + assert o_proj.weight_scale.dtype == torch.float32 + assert o_proj.weight_scale.dim() == 1 + assert not hasattr(o_proj, "input_scale") + + assert hasattr(gate_up_proj, "weight_scale") + assert gate_up_proj.weight_scale.dtype == torch.float32 + assert gate_up_proj.weight_scale.dim() == 1 + assert not hasattr(gate_up_proj, "input_scale") + + assert hasattr(down_proj, "weight_scale") + assert down_proj.weight_scale.dtype == torch.float32 + assert down_proj.weight_scale.dim() == 1 + assert not hasattr(down_proj, "input_scale") + + llm.apply_model(check_model) + + output = llm.generate_greedy(["Hello my name is"], max_tokens=4) + assert output + print(f"ModelOpt FP8_PER_CHANNEL_PER_TOKEN output: {output}") + + +@pytest.mark.skipif( + not is_quant_method_supported("modelopt"), + reason="ModelOpt FP8 is not supported on this GPU type.", +) +def test_modelopt_fp8_pb_wo_checkpoint_setup(vllm_runner): + """Test ModelOpt FP8_PB_WO checkpoint setup.""" + model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo" + model_path = _snapshot_download_or_skip(model_id) + + with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptFp8PbWoLinearMethod, + ) + + assert isinstance(qkv_proj.quant_method, ModelOptFp8PbWoLinearMethod) + assert isinstance(o_proj.quant_method, ModelOptFp8PbWoLinearMethod) + assert isinstance(gate_up_proj.quant_method, ModelOptFp8PbWoLinearMethod) + assert isinstance(down_proj.quant_method, ModelOptFp8PbWoLinearMethod) + + assert qkv_proj.weight.dtype == torch.float8_e4m3fn + assert o_proj.weight.dtype == torch.float8_e4m3fn + assert gate_up_proj.weight.dtype == torch.float8_e4m3fn + assert down_proj.weight.dtype == torch.float8_e4m3fn + + # Block scales; should be materialized as a 2D [out_blk, in_blk] tensor. + assert hasattr(qkv_proj, "weight_scale") + assert qkv_proj.weight_scale.dtype == torch.float32 + assert qkv_proj.weight_scale.dim() == 2 + + assert hasattr(o_proj, "weight_scale") + assert o_proj.weight_scale.dtype == torch.float32 + assert o_proj.weight_scale.dim() == 2 + + assert hasattr(gate_up_proj, "weight_scale") + assert gate_up_proj.weight_scale.dtype == torch.float32 + assert gate_up_proj.weight_scale.dim() == 2 + + assert hasattr(down_proj, "weight_scale") + assert down_proj.weight_scale.dtype == torch.float32 + assert down_proj.weight_scale.dim() == 2 + + llm.apply_model(check_model) + + output = llm.generate_greedy(["Hello my name is"], max_tokens=4) + assert output + print(f"ModelOpt FP8_PB_WO output: {output}") diff --git a/vllm/config/model.py b/vllm/config/model.py index db5789b709372..c796e300ab155 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -843,12 +843,18 @@ class ModelConfig: producer_name = quant_cfg.get("producer", {}).get("name") if producer_name == "modelopt": quant_algo = quant_cfg.get("quantization", {}).get("quant_algo") - if quant_algo == "FP8": - quant_cfg["quant_method"] = "modelopt" - elif quant_algo == "NVFP4": - quant_cfg["quant_method"] = "modelopt_fp4" - elif quant_algo is not None: - raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") + if quant_algo is not None: + quant_algo_upper = str(quant_algo).upper() + if quant_algo_upper in { + "FP8", + "FP8_PER_CHANNEL_PER_TOKEN", + "FP8_PB_WO", + }: + quant_cfg["quant_method"] = "modelopt" + elif quant_algo_upper == "NVFP4": + quant_cfg["quant_method"] = "modelopt_fp4" + else: + raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") return quant_cfg diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 4ca4f75711ac7..402f0bf69ceaa 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -53,6 +53,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "GPTQLinearMethod", "FBGEMMFp8LinearMethod", "ModelOptFp8LinearMethod", + "ModelOptFp8PcPtLinearMethod", + "ModelOptFp8PbWoLinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod", "HQQMarlinMethod", diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 54e8673fcfbb8..afbefe1fedc18 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( select_cutlass_fp8_gemm_impl, swap_w13_to_w31, ) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( get_marlin_input_dtype, ) @@ -72,9 +75,15 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, + cutlass_block_fp8_supported, requantize_with_max_scale, ) -from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter +from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) from vllm.scalar_type import scalar_types from vllm.utils.flashinfer import ( flashinfer_scaled_fp4_mm, @@ -88,7 +97,16 @@ if TYPE_CHECKING: logger = init_logger(__name__) -QUANT_ALGOS = ["FP8", "NVFP4"] +QUANT_ALGOS = [ + # FP8 (per-tensor weight + optional static activation scale). + "FP8", + # FP8 per-channel weight scale + per-token activation scale. + "FP8_PER_CHANNEL_PER_TOKEN", + # FP8 per-block weight-only (ModelOpt may emit this as lowercase). + "FP8_PB_WO", + # FP4 + "NVFP4", +] KV_CACHE_QUANT_ALGOS = ["FP8"] @@ -255,6 +273,9 @@ class ModelOptQuantConfigBase(QuantizationConfig): if not quant_method: raise ValueError("Missing 'quant_algo' in quantization config") + # Normalize quant_algo for robust matching (ModelOpt may emit lowercase). + quant_method = str(quant_method).upper() + if kv_cache_quant_method is None: # No KV cache quantization, keep this branch just to have this comment pass @@ -263,6 +284,8 @@ class ModelOptQuantConfigBase(QuantizationConfig): f"kv_cache_quant_algo must be a string, got " f"{type(kv_cache_quant_method)}" ) + else: + kv_cache_quant_method = kv_cache_quant_method.upper() if not isinstance(exclude_modules, list): raise ValueError( @@ -302,17 +325,34 @@ class ModelOptFp8Config(ModelOptQuantConfigBase): def __init__( self, + quant_method: str, is_checkpoint_fp8_serialized: bool, kv_cache_quant_method: str | None, exclude_modules: list[str], ) -> None: super().__init__(exclude_modules) + self.quant_method = quant_method self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.kv_cache_quant_method = kv_cache_quant_method if is_checkpoint_fp8_serialized: logger.warning( - "Detected ModelOpt fp8 checkpoint. Please note that" - " the format is experimental and could change." + "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note " + "that the format is experimental and could change.", + quant_method, + ) + + # Select LinearMethod implementation based on quant_algo. + if self.quant_method == "FP8": + self.LinearMethodCls = ModelOptFp8LinearMethod + elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN": + self.LinearMethodCls = ModelOptFp8PcPtLinearMethod + elif self.quant_method == "FP8_PB_WO": + self.LinearMethodCls = ModelOptFp8PbWoLinearMethod + else: + raise ValueError( + "Unsupported ModelOpt FP8 quant_algo for vLLM: " + f"{self.quant_method}. Supported: FP8 / " + "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO." ) def get_name(self) -> QuantizationMethods: @@ -346,13 +386,13 @@ class ModelOptFp8Config(ModelOptQuantConfigBase): if "quantization" in hf_quant_cfg: quant_config = hf_quant_cfg["quantization"] if isinstance(quant_config, dict): - quant_algo = quant_config.get("quant_algo", "") - if "FP8" in quant_algo: + quant_algo = str(quant_config.get("quant_algo", "")) + if "FP8" in quant_algo.upper(): return "modelopt" else: # Check for compressed-tensors style config with specific quant_algo - quant_algo = hf_quant_cfg.get("quant_algo", "") - if isinstance(quant_algo, str) and "FP8" in quant_algo: + quant_algo = str(hf_quant_cfg.get("quant_algo", "")) + if "FP8" in quant_algo.upper(): return "modelopt" return None @@ -369,7 +409,12 @@ class ModelOptFp8Config(ModelOptQuantConfigBase): ) -> "ModelOptFp8Config": is_checkpoint_fp8_serialized = "FP8" in quant_method - return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules) + return cls( + quant_method, + is_checkpoint_fp8_serialized, + kv_cache_quant_method, + exclude_modules, + ) class ModelOptFp8LinearMethod(LinearMethodBase): @@ -464,6 +509,203 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ) +class ModelOptFp8PcPtLinearMethod(LinearMethodBase): + """Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints. + + Expected checkpoint structure (per Linear): + - weight: fp8-e4m3fn, shape [out, in] + - weight_scale: fp32, shape [out] (per-output-channel) + - no input_scale (activations are dynamically quantized per-token) + """ + + def __init__(self, quant_config: ModelOptFp8Config) -> None: + self.quant_config = quant_config + self.fp8_linear = Fp8LinearOp( + act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "FP8_PER_CHANNEL_PER_TOKEN currently only supports " + "FP8-serialized checkpoints." + ) + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = ChannelQuantScaleParameter( + data=torch.empty(output_size_per_partition, dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: Module) -> None: + layer.weight = Parameter(layer.weight.t(), requires_grad=False) + layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + bias=bias, + ) + + +class ModelOptFp8PbWoLinearMethod(LinearMethodBase): + """Linear method for ModelOpt FP8_PB_WO checkpoints. + + ModelOpt exports `weight_scale` as a 4D tensor: + [out_blk, 1, in_blk, 1] + where block size is typically 128 for both dims. + + vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant. + """ + + _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128) + + def __init__(self, quant_config: ModelOptFp8Config) -> None: + self.quant_config = quant_config + block_n, block_k = self._WEIGHT_BLOCK_SIZE + self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE) + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(block_n, block_k), + act_quant_group_shape=GroupShape(1, block_k), + cutlass_block_fp8_supported=cutlass_block_fp8_supported(), + use_aiter_and_is_supported=False, + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "FP8_PB_WO currently only supports FP8-serialized checkpoints." + ) + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Expose block size so the v2 weight loaders can translate offsets from + # element-space -> block-space for BlockQuantScaleParameter. + layer.weight_block_size = self.weight_block_size + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + block_n, block_k = self._WEIGHT_BLOCK_SIZE + if output_size_per_partition % block_n != 0: + raise ValueError( + "ModelOpt FP8_PB_WO requires out_features divisible by " + f"{block_n}, got {output_size_per_partition}." + ) + if input_size_per_partition % block_k != 0: + raise ValueError( + "ModelOpt FP8_PB_WO requires in_features divisible by " + f"{block_k}, got {input_size_per_partition}." + ) + + out_blks = output_size_per_partition // block_n + in_blks = input_size_per_partition // block_k + + # Match ModelOpt's exported shape so weight loading works without a + # custom loader: [out_blk, 1, in_blk, 1] + weight_scale = BlockQuantScaleParameter( + data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32), + input_dim=2, + output_dim=0, + weight_loader=weight_loader, + ) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: Module) -> None: + # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp. + layer.weight = Parameter(layer.weight.data, requires_grad=False) + + scale = layer.weight_scale + if scale.dim() == 4: + # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk] + scale = scale.squeeze(1).squeeze(-1) + elif scale.dim() != 2: + raise ValueError( + "Unexpected ModelOpt FP8_PB_WO weight_scale shape: " + f"{tuple(scale.shape)}." + ) + + layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.w8a8_block_fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + bias=bias, + ) + + class ModelOptFp8MoEMethod(FusedMoEMethodBase): """MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight scale and