[Feature]: Support NVIDIA ModelOpt HF FP8 variants FP8_PER_CHANNEL_PER_TOKEN and FP8_PB_WO in vLLM (#30957)

CedricHuang 2025-12-22 11:34:49 +08:00 committed by GitHub
parent 097978a15d
commit 19cc9468fd
5 changed files with 437 additions and 15 deletions

View File

@@ -8,6 +8,16 @@ We recommend installing the library with:
pip install nvidia-modelopt
```
## Supported ModelOpt checkpoint formats
vLLM detects ModelOpt checkpoints via `hf_quant_config.json` and supports the
following `quantization.quant_algo` values:
- `FP8`: per-tensor weight scale (+ optional static activation scale).
- `FP8_PER_CHANNEL_PER_TOKEN`: per-channel weight scale and dynamic per-token activation quantization.
- `FP8_PB_WO` (ModelOpt may emit `fp8_pb_wo`): block-scaled FP8 weight-only (typically 128×128 blocks).
- `NVFP4`: ModelOpt NVFP4 checkpoints (use `quantization="modelopt_fp4"`).
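vLLM performs this detection in `ModelConfig` (see the `vllm/config.py` hunk later in this commit). The snippet below is a minimal sketch of that mapping, assuming the usual `hf_quant_config.json` layout with `producer.name` and `quantization.quant_algo`; `detect_modelopt_quant_method` is an illustrative helper, not part of vLLM's public API:
```python
import json
from pathlib import Path

def detect_modelopt_quant_method(checkpoint_dir: str) -> str | None:
    """Mirror vLLM's mapping from ModelOpt quant_algo to a quant method name."""
    cfg_path = Path(checkpoint_dir) / "hf_quant_config.json"
    if not cfg_path.is_file():
        return None
    cfg = json.loads(cfg_path.read_text())
    if cfg.get("producer", {}).get("name") != "modelopt":
        return None
    # quant_algo may be emitted lowercase (e.g. fp8_pb_wo), so normalize.
    algo = str(cfg.get("quantization", {}).get("quant_algo", "")).upper()
    if algo in {"FP8", "FP8_PER_CHANNEL_PER_TOKEN", "FP8_PB_WO"}:
        return "modelopt"
    if algo == "NVFP4":
        return "modelopt_fp4"
    return None
```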
## Quantizing HuggingFace Models with PTQ
You can quantize HuggingFace models using the example scripts provided in the Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory.
@@ -80,3 +90,24 @@ The quantized checkpoint can then be deployed with vLLM. As an example, the foll
if __name__ == "__main__":
main()
```
## Running the OpenAI-compatible server
To serve a local ModelOpt checkpoint via the OpenAI-compatible API:
```bash
vllm serve <path_to_exported_checkpoint> \
--quantization modelopt \
--host 0.0.0.0 --port 8000
```
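Once the server is running, any OpenAI-compatible client can query it. Below is a minimal sketch using the `openai` Python package; by default the served model name is the path passed to `vllm serve` (override it with `--served-model-name` if desired):
```python
from openai import OpenAI

# The API key is unused unless the server was started with --api-key.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="<path_to_exported_checkpoint>",
    prompt="Hello my name is",
    max_tokens=16,
)
print(completion.choices[0].text)
```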
## Testing (local checkpoints)
vLLM's ModelOpt unit tests are gated by local checkpoint paths and are skipped
by default in CI. To run the tests locally:
```bash
export VLLM_TEST_MODELOPT_FP8_PC_PT_MODEL_PATH=<path_to_fp8_pc_pt_checkpoint>
export VLLM_TEST_MODELOPT_FP8_PB_WO_MODEL_PATH=<path_to_fp8_pb_wo_checkpoint>
pytest -q tests/quantization/test_modelopt.py
```
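For a quick sanity check outside pytest, a local checkpoint can also be loaded with the offline `LLM` API. A sketch assuming `VLLM_TEST_MODELOPT_FP8_PC_PT_MODEL_PATH` points at an FP8_PER_CHANNEL_PER_TOKEN export:
```python
import os

from vllm import LLM, SamplingParams

# Load the locally exported checkpoint and generate a few greedy-ish tokens.
model_path = os.environ["VLLM_TEST_MODELOPT_FP8_PC_PT_MODEL_PATH"]
llm = LLM(model=model_path, quantization="modelopt", enforce_eager=True)
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=4))
print(outputs[0].outputs[0].text)
```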

View File

@@ -6,6 +6,7 @@ Run `pytest tests/quantization/test_modelopt.py`.
"""
import os
from typing import NoReturn
import pytest
import torch
@@ -19,6 +20,28 @@ def enable_pickle(monkeypatch):
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
def _skip(msg: str) -> NoReturn:
pytest.skip(msg)
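# pytest.skip() always raises, so the RuntimeError below is unreachable;
# it exists only so static type checkers accept the NoReturn annotation.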
raise RuntimeError(msg)
def _snapshot_download_or_skip(model_id: str) -> str:
try:
from huggingface_hub import snapshot_download
except Exception as e: # pragma: no cover
_skip(f"huggingface_hub is required to download {model_id}: {e}")
try:
return snapshot_download(
repo_id=model_id,
repo_type="model",
# These checkpoints are already small; download the full repo for simplicity.
allow_patterns=["*"],
)
except Exception as e:
_skip(f"Failed to download {model_id} from the HF Hub: {e}")
@pytest.mark.skipif(
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
@@ -91,3 +114,121 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
assert output
print(f"ModelOpt FP8 output: {output}")
@pytest.mark.skipif(
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
"""Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup."""
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt"
model_path = _snapshot_download_or_skip(model_id)
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
o_proj = layer.self_attn.o_proj
gate_up_proj = layer.mlp.gate_up_proj
down_proj = layer.mlp.down_proj
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8PcPtLinearMethod,
)
assert isinstance(qkv_proj.quant_method, ModelOptFp8PcPtLinearMethod)
assert isinstance(o_proj.quant_method, ModelOptFp8PcPtLinearMethod)
assert isinstance(gate_up_proj.quant_method, ModelOptFp8PcPtLinearMethod)
assert isinstance(down_proj.quant_method, ModelOptFp8PcPtLinearMethod)
assert qkv_proj.weight.dtype == torch.float8_e4m3fn
assert o_proj.weight.dtype == torch.float8_e4m3fn
assert gate_up_proj.weight.dtype == torch.float8_e4m3fn
assert down_proj.weight.dtype == torch.float8_e4m3fn
# Per-channel scales; activations are dynamically scaled per token.
assert hasattr(qkv_proj, "weight_scale")
assert qkv_proj.weight_scale.dtype == torch.float32
assert qkv_proj.weight_scale.dim() == 1
assert not hasattr(qkv_proj, "input_scale")
assert hasattr(o_proj, "weight_scale")
assert o_proj.weight_scale.dtype == torch.float32
assert o_proj.weight_scale.dim() == 1
assert not hasattr(o_proj, "input_scale")
assert hasattr(gate_up_proj, "weight_scale")
assert gate_up_proj.weight_scale.dtype == torch.float32
assert gate_up_proj.weight_scale.dim() == 1
assert not hasattr(gate_up_proj, "input_scale")
assert hasattr(down_proj, "weight_scale")
assert down_proj.weight_scale.dtype == torch.float32
assert down_proj.weight_scale.dim() == 1
assert not hasattr(down_proj, "input_scale")
llm.apply_model(check_model)
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
assert output
print(f"ModelOpt FP8_PER_CHANNEL_PER_TOKEN output: {output}")
@pytest.mark.skipif(
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_pb_wo_checkpoint_setup(vllm_runner):
"""Test ModelOpt FP8_PB_WO checkpoint setup."""
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo"
model_path = _snapshot_download_or_skip(model_id)
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
o_proj = layer.self_attn.o_proj
gate_up_proj = layer.mlp.gate_up_proj
down_proj = layer.mlp.down_proj
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8PbWoLinearMethod,
)
assert isinstance(qkv_proj.quant_method, ModelOptFp8PbWoLinearMethod)
assert isinstance(o_proj.quant_method, ModelOptFp8PbWoLinearMethod)
assert isinstance(gate_up_proj.quant_method, ModelOptFp8PbWoLinearMethod)
assert isinstance(down_proj.quant_method, ModelOptFp8PbWoLinearMethod)
assert qkv_proj.weight.dtype == torch.float8_e4m3fn
assert o_proj.weight.dtype == torch.float8_e4m3fn
assert gate_up_proj.weight.dtype == torch.float8_e4m3fn
assert down_proj.weight.dtype == torch.float8_e4m3fn
# Block scales; should be materialized as a 2D [out_blk, in_blk] tensor.
assert hasattr(qkv_proj, "weight_scale")
assert qkv_proj.weight_scale.dtype == torch.float32
assert qkv_proj.weight_scale.dim() == 2
assert hasattr(o_proj, "weight_scale")
assert o_proj.weight_scale.dtype == torch.float32
assert o_proj.weight_scale.dim() == 2
assert hasattr(gate_up_proj, "weight_scale")
assert gate_up_proj.weight_scale.dtype == torch.float32
assert gate_up_proj.weight_scale.dim() == 2
assert hasattr(down_proj, "weight_scale")
assert down_proj.weight_scale.dtype == torch.float32
assert down_proj.weight_scale.dim() == 2
llm.apply_model(check_model)
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
assert output
print(f"ModelOpt FP8_PB_WO output: {output}")

View File

@@ -843,12 +843,18 @@ class ModelConfig:
producer_name = quant_cfg.get("producer", {}).get("name")
if producer_name == "modelopt":
quant_algo = quant_cfg.get("quantization", {}).get("quant_algo")
if quant_algo == "FP8":
quant_cfg["quant_method"] = "modelopt"
elif quant_algo == "NVFP4":
quant_cfg["quant_method"] = "modelopt_fp4"
elif quant_algo is not None:
raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")
if quant_algo is not None:
quant_algo_upper = str(quant_algo).upper()
if quant_algo_upper in {
"FP8",
"FP8_PER_CHANNEL_PER_TOKEN",
"FP8_PB_WO",
}:
quant_cfg["quant_method"] = "modelopt"
elif quant_algo_upper == "NVFP4":
quant_cfg["quant_method"] = "modelopt_fp4"
else:
raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")
return quant_cfg

View File

@@ -53,6 +53,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"GPTQLinearMethod",
"FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod",
"ModelOptFp8PcPtLinearMethod",
"ModelOptFp8PbWoLinearMethod",
"IPEXAWQLinearMethod",
"IPEXGPTQLinearMethod",
"HQQMarlinMethod",

View File

@@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
@@ -72,9 +75,15 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported,
requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
@@ -88,7 +97,16 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
QUANT_ALGOS = ["FP8", "NVFP4"]
QUANT_ALGOS = [
# FP8 (per-tensor weight + optional static activation scale).
"FP8",
# FP8 per-channel weight scale + per-token activation scale.
"FP8_PER_CHANNEL_PER_TOKEN",
# FP8 per-block weight-only (ModelOpt may emit this as lowercase).
"FP8_PB_WO",
# FP4
"NVFP4",
]
KV_CACHE_QUANT_ALGOS = ["FP8"]
@@ -255,6 +273,9 @@ class ModelOptQuantConfigBase(QuantizationConfig):
if not quant_method:
raise ValueError("Missing 'quant_algo' in quantization config")
# Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
quant_method = str(quant_method).upper()
if kv_cache_quant_method is None:
# No KV cache quantization, keep this branch just to have this comment
pass
@@ -263,6 +284,8 @@ class ModelOptQuantConfigBase(QuantizationConfig):
f"kv_cache_quant_algo must be a string, got "
f"{type(kv_cache_quant_method)}"
)
else:
kv_cache_quant_method = kv_cache_quant_method.upper()
if not isinstance(exclude_modules, list):
raise ValueError(
@@ -302,17 +325,34 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
def __init__(
self,
quant_method: str,
is_checkpoint_fp8_serialized: bool,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
) -> None:
super().__init__(exclude_modules)
self.quant_method = quant_method
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.kv_cache_quant_method = kv_cache_quant_method
if is_checkpoint_fp8_serialized:
logger.warning(
"Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change."
"Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
"that the format is experimental and could change.",
quant_method,
)
# Select LinearMethod implementation based on quant_algo.
if self.quant_method == "FP8":
self.LinearMethodCls = ModelOptFp8LinearMethod
elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN":
self.LinearMethodCls = ModelOptFp8PcPtLinearMethod
elif self.quant_method == "FP8_PB_WO":
self.LinearMethodCls = ModelOptFp8PbWoLinearMethod
else:
raise ValueError(
"Unsupported ModelOpt FP8 quant_algo for vLLM: "
f"{self.quant_method}. Supported: FP8 / "
"FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
)
def get_name(self) -> QuantizationMethods:
@@ -346,13 +386,13 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = quant_config.get("quant_algo", "")
if "FP8" in quant_algo:
quant_algo = str(quant_config.get("quant_algo", ""))
if "FP8" in quant_algo.upper():
return "modelopt"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = hf_quant_cfg.get("quant_algo", "")
if isinstance(quant_algo, str) and "FP8" in quant_algo:
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
if "FP8" in quant_algo.upper():
return "modelopt"
return None
@@ -369,7 +409,12 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
) -> "ModelOptFp8Config":
is_checkpoint_fp8_serialized = "FP8" in quant_method
return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)
return cls(
quant_method,
is_checkpoint_fp8_serialized,
kv_cache_quant_method,
exclude_modules,
)
class ModelOptFp8LinearMethod(LinearMethodBase):
@@ -464,6 +509,203 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
)
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
"""Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints.
Expected checkpoint structure (per Linear):
- weight: fp8-e4m3fn, shape [out, in]
- weight_scale: fp32, shape [out] (per-output-channel)
- no input_scale (activations are dynamically quantized per-token)
"""
def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"FP8_PER_CHANNEL_PER_TOKEN currently only supports "
"FP8-serialized checkpoints."
)
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
weight_scale = ChannelQuantScaleParameter(
data=torch.empty(output_size_per_partition, dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
bias=bias,
)
class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
"""Linear method for ModelOpt FP8_PB_WO checkpoints.
ModelOpt exports `weight_scale` as a 4D tensor:
[out_blk, 1, in_blk, 1]
where block size is typically 128 for both dims.
vLLM executes it as a block-scaled FP8 GEMM with *dynamic per-token-group*
(1x128) activation quantization.
"""
_WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)
def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
block_n, block_k = self._WEIGHT_BLOCK_SIZE
self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(block_n, block_k),
act_quant_group_shape=GroupShape(1, block_k),
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=False,
)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"FP8_PB_WO currently only supports FP8-serialized checkpoints."
)
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
# Expose block size so the v2 weight loaders can translate offsets from
# element-space -> block-space for BlockQuantScaleParameter.
layer.weight_block_size = self.weight_block_size
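# e.g. an output-channel offset of 256 elements maps to scale row
# 256 // 128 = 2 when loading weight_scale (illustrative values).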
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
block_n, block_k = self._WEIGHT_BLOCK_SIZE
if output_size_per_partition % block_n != 0:
raise ValueError(
"ModelOpt FP8_PB_WO requires out_features divisible by "
f"{block_n}, got {output_size_per_partition}."
)
if input_size_per_partition % block_k != 0:
raise ValueError(
"ModelOpt FP8_PB_WO requires in_features divisible by "
f"{block_k}, got {input_size_per_partition}."
)
out_blks = output_size_per_partition // block_n
in_blks = input_size_per_partition // block_k
# Match ModelOpt's exported shape so weight loading works without a
# custom loader: [out_blk, 1, in_blk, 1]
weight_scale = BlockQuantScaleParameter(
data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
input_dim=2,
output_dim=0,
weight_loader=weight_loader,
)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: Module) -> None:
# Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
layer.weight = Parameter(layer.weight.data, requires_grad=False)
scale = layer.weight_scale
if scale.dim() == 4:
# [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
scale = scale.squeeze(1).squeeze(-1)
elif scale.dim() != 2:
raise ValueError(
"Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
f"{tuple(scale.shape)}."
)
layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
bias=bias,
)
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
"""MoE method for ModelOpt FP8.
Supports loading FP8 checkpoints with static weight scale and