Add Nvidia ModelOpt config adaptation (#19815)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Zhiyu 2025-07-21 07:02:58 -07:00 committed by GitHub
parent d97841078b
commit 6b46c4b653
3 changed files with 287 additions and 32 deletions

tests/quantization/test_modelopt.py

@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test ModelOpt quantization method setup and weight loading.
Run `pytest tests/quantization/test_modelopt.py`.
"""
import os
import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm.platforms import current_platform
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
if not current_platform.is_cpu():
monkeypatch.setenv('VLLM_USE_V1', '0')
@pytest.mark.skipif(not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.")
def test_modelopt_fp8_checkpoint_setup(vllm_runner):
"""Test ModelOpt FP8 checkpoint loading and structure validation."""
# TODO: provide a small publicly available test checkpoint
model_path = ("/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/"
"TinyLlama-1.1B-Chat-v1.0-fp8-0710")
# Skip test if checkpoint doesn't exist
if not os.path.exists(model_path):
pytest.skip(f"Test checkpoint not found at {model_path}. "
"This test requires a local ModelOpt FP8 checkpoint.")
with vllm_runner(model_path, quantization="modelopt",
enforce_eager=True) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
o_proj = layer.self_attn.o_proj
gate_up_proj = layer.mlp.gate_up_proj
down_proj = layer.mlp.down_proj
# Check that ModelOpt quantization method is properly applied
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8LinearMethod)
assert isinstance(qkv_proj.quant_method, ModelOptFp8LinearMethod)
assert isinstance(o_proj.quant_method, ModelOptFp8LinearMethod)
assert isinstance(gate_up_proj.quant_method,
ModelOptFp8LinearMethod)
assert isinstance(down_proj.quant_method, ModelOptFp8LinearMethod)
# Check weight dtype is FP8
assert qkv_proj.weight.dtype == torch.float8_e4m3fn
assert o_proj.weight.dtype == torch.float8_e4m3fn
assert gate_up_proj.weight.dtype == torch.float8_e4m3fn
assert down_proj.weight.dtype == torch.float8_e4m3fn
# Check scales are present and have correct dtype
assert hasattr(qkv_proj, 'weight_scale')
assert hasattr(qkv_proj, 'input_scale')
assert qkv_proj.weight_scale.dtype == torch.float32
assert qkv_proj.input_scale.dtype == torch.float32
assert hasattr(o_proj, 'weight_scale')
assert hasattr(o_proj, 'input_scale')
assert o_proj.weight_scale.dtype == torch.float32
assert o_proj.input_scale.dtype == torch.float32
assert hasattr(gate_up_proj, 'weight_scale')
assert hasattr(gate_up_proj, 'input_scale')
assert gate_up_proj.weight_scale.dtype == torch.float32
assert gate_up_proj.input_scale.dtype == torch.float32
assert hasattr(down_proj, 'weight_scale')
assert hasattr(down_proj, 'input_scale')
assert down_proj.weight_scale.dtype == torch.float32
assert down_proj.input_scale.dtype == torch.float32
llm.apply_model(check_model)
# Run a simple generation test to ensure the model works
output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
assert output
print(f"ModelOpt FP8 output: {output}")
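
For reference, checkpoints of this kind ship an hf_quant_config.json that the loader changes below parse. A minimal sketch of its shape, assuming the ModelOpt-native nested layout; the values are illustrative and do not describe the private checkpoint used above.

# Hedged sketch of a ModelOpt-style hf_quant_config.json, written as a Python
# dict. Only the keys read by the detection and parsing code in this commit
# are shown; the values are made up for illustration.
modelopt_fp8_hf_quant_config = {
    "quant_method": "modelopt",          # consulted by auto-detection
    "quantization": {
        "quant_algo": "FP8",             # selects ModelOptFp8LinearMethod
        "kv_cache_quant_algo": "FP8",    # optional KV-cache quantization
        "exclude_modules": ["lm_head"],  # modules left unquantized
    },
}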

vllm/config.py

@@ -346,11 +346,11 @@ class ModelConfig:
"""Maximum number of data items per modality per prompt. Only applicable
for multimodal models."""
interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string. Defaults to False."""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
use_async_output_proc: bool = True
"""Whether to use async output processor."""
@@ -1000,9 +1000,13 @@ class ModelConfig:
quant_cfg = self._parse_quant_hf_config()
if quant_cfg is not None:
# Use the community standard 'quant_method'
quant_method = quant_cfg.get("quant_method", "").lower()
# Normalize library names
quant_method = quant_method.replace("compressed_tensors",
"compressed-tensors")
quant_cfg["quant_method"] = quant_method
# Quantization methods which are overrides (i.e. they have a
@@ -1017,6 +1021,8 @@ class ModelConfig:
"awq_marlin",
"ipex",
"moe_wna16",
"modelopt",
"modelopt_fp4",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides
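
Listing modelopt and modelopt_fp4 under overrides means these backends are asked, via override_quantization_method, whether to claim a checkpoint before the declared quant_method is used as-is. Below is a simplified sketch of that flow, assuming vLLM's get_quantization_config helper; it is not the exact selection loop in ModelConfig.

# Simplified, hedged sketch of the override flow; not the exact vLLM code.
from typing import Optional

from vllm.model_executor.layers.quantization import get_quantization_config


def resolve_quant_method(quant_cfg: dict,
                         user_quant: Optional[str]) -> Optional[str]:
    # Each override-capable backend inspects the HF quant config and may
    # claim it by returning its own method name.
    for name in ("modelopt", "modelopt_fp4"):
        quant_cls = get_quantization_config(name)
        override = quant_cls.override_quantization_method(
            quant_cfg, user_quant)
        if override is not None:
            return override
    # Otherwise fall back to whatever the config itself declares.
    return quant_cfg.get("quant_method")
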
@@ -3185,8 +3191,8 @@ class MultiModalConfig:
"""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
mm_processor_kwargs: Optional[dict[str, object]] = None
@@ -4086,7 +4092,7 @@ class CompilationConfig:
- True: inductor compilation is used (custom_ops disabled by default).
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.
This setting is ignored if level<PIECEWISE."""
compile_sizes: Optional[list[Union[int, str]]] = None
"""Sizes to compile for inductor. In addition
@@ -4385,7 +4391,7 @@ class VllmConfig:
As a shorthand, `-O<n>` can be used to directly specify the compilation
level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
Currently, -O <n> and -O=<n> are supported as well but this will likely be
removed in favor of clearer -O<n> syntax in the future.
NOTE: level 0 is the default level without any optimization. level 1 and 2

vllm/model_executor/layers/quantization/modelopt.py

@@ -75,20 +75,64 @@ class ModelOptFp8Config(QuantizationConfig):
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
"""Detect if this ModelOpt config should be used based on
quantization config."""
if hf_quant_cfg is None:
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = quant_config.get("quant_algo", "")
if "FP8" in quant_algo:
return "modelopt"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = hf_quant_cfg.get("quant_algo", "")
if isinstance(quant_algo, str) and "FP8" in quant_algo:
return "modelopt"
return None
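
Both config layouts accepted by this detection can be exercised directly. A hedged illustration with minimal hand-written dicts, not real checkpoint configs:

# Minimal, illustrative configs; only the keys inspected above are present.
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

nested_style = {
    "quant_method": "modelopt",
    "quantization": {"quant_algo": "FP8"},  # ModelOpt-native nested layout
}
flat_style = {
    "quant_method": "modelopt",
    "quant_algo": "FP8",  # compressed-tensors style flat layout
}

assert ModelOptFp8Config.override_quantization_method(
    nested_style, None) == "modelopt"
assert ModelOptFp8Config.override_quantization_method(
    flat_style, None) == "modelopt"
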
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
-quant_config = cls.get_from_keys(config, ["quantization"])
-quant_method = quant_config["quant_algo"]
-kv_cache_quant_method = cls.get_from_keys(
-config, ["quantization"]).get("kv_cache_quant_algo")
-exclude_modules = cls.get_from_keys(
-config, ["quantization"]).get("exclude_modules")
# Handle both ModelOpt format and compressed-tensors style format
if "quantization" in config:
# ModelOpt format: {"quantization": {"quant_algo": "..."}}
quant_config = cls.get_from_keys(config, ["quantization"])
if not isinstance(quant_config, dict):
raise ValueError(
"Expected 'quantization' to be a dictionary in config")
quant_method = quant_config.get("quant_algo", "")
if not quant_method:
raise ValueError("Missing 'quant_algo' in quantization config")
kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
exclude_modules = quant_config.get("exclude_modules")
else:
# Compressed-tensors style format:
# {"quant_algo": "...", "quant_method": "modelopt"}
quant_method = config.get("quant_algo", "")
kv_cache_quant_method = config.get("kv_cache_quant_algo")
exclude_modules = config.get("exclude_modules")
if quant_method not in QUANT_ALGOS:
-raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
-" quantizations in vLLM. Please check the "
-"`hf_quant_config.json` file for your model's "
-"quant configuration.")
raise ValueError(
f"ModelOpt currently only supports: {QUANT_ALGOS} "
"quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.")
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
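
A usage sketch of the new flat path with made-up values; the attribute checked at the end mirrors the local variable name above and is assumed to be stored unchanged on the config object.

# Illustrative only: the flat compressed-tensors style layout is now parsed
# directly by from_config instead of requiring the nested "quantization" key.
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

fp8_cfg = ModelOptFp8Config.from_config({
    "quant_method": "modelopt",
    "quant_algo": "FP8",
    "kv_cache_quant_algo": "FP8",
    "exclude_modules": ["lm_head"],
})
assert fp8_cfg.is_checkpoint_fp8_serialized  # assumed attribute name
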
@@ -434,7 +478,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
def __init__(
self,
is_checkpoint_nvfp4_serialized: bool,
-kv_cache_quant_algo: str,
kv_cache_quant_algo: Optional[str],
exclude_modules: list[str],
group_size: int = 16,
) -> None:
@@ -465,24 +509,138 @@ class ModelOptNvFp4Config(QuantizationConfig):
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
"""Detect if this ModelOpt FP4 config should be used based on
quantization config."""
if hf_quant_cfg is None:
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = quant_config.get("quant_algo", "")
if "NVFP4" in quant_algo:
return "modelopt_fp4"
else:
# Check for compressed-tensors style config with specific
# quant_algo field
quant_algo = hf_quant_cfg.get("quant_algo", "")
if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
return "modelopt_fp4"
return None
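
A quick hedged check of the FP4 detection; note that the FP8 variant above would not claim this config, since its check looks for "FP8" in quant_algo.

# Illustrative NVFP4 config in the nested layout (values are made up).
from vllm.model_executor.layers.quantization.modelopt import (
    ModelOptNvFp4Config)

nvfp4_style = {
    "quant_method": "modelopt",
    "quantization": {"quant_algo": "NVFP4"},
}
assert ModelOptNvFp4Config.override_quantization_method(
    nvfp4_style, None) == "modelopt_fp4"
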
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
-quant_config = cls.get_from_keys(config, ["quantization"])
-quant_method = quant_config["quant_algo"]
# Handle both traditional ModelOpt format and compressed-tensors
# style format
if "quantization" in config:
# Traditional ModelOpt format:
# {"quantization": {"quant_algo": "..."}}
quant_config = cls.get_from_keys(config, ["quantization"])
if not isinstance(quant_config, dict):
raise ValueError(
"Expected 'quantization' to be a dictionary in config")
quant_method = quant_config.get("quant_algo", "")
if not quant_method:
raise ValueError("Missing 'quant_algo' in quantization config")
# Handle kv_cache_quant_algo with proper type validation
kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
if kv_cache_quant_algo_raw is None:
# No KV cache quantization by default
kv_cache_quant_algo = None
elif isinstance(kv_cache_quant_algo_raw, str):
kv_cache_quant_algo = kv_cache_quant_algo_raw
else:
raise ValueError(f"kv_cache_quant_algo must be a string, got "
f"{type(kv_cache_quant_algo_raw)}")
# Handle group_size with proper type validation
group_size_raw = quant_config.get("group_size")
if group_size_raw is None:
group_size = 16 # Default value
elif isinstance(group_size_raw, int):
group_size = group_size_raw
else:
try:
group_size = int(group_size_raw)
except (ValueError, TypeError):
raise ValueError(f"group_size must be an integer, got "
f"{type(group_size_raw)}") from None
exclude_modules = quant_config.get("exclude_modules", [])
if not isinstance(exclude_modules, list):
raise ValueError(f"exclude_modules must be a list, got "
f"{type(exclude_modules)}")
else:
# Compressed-tensors style format:
# {"quant_algo": "...", "quant_method": "modelopt"}
quant_method = config.get("quant_algo", "")
# Handle kv_cache_quant_algo with proper type validation
kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
if kv_cache_quant_algo_raw is None:
# No KV cache quantization by default
kv_cache_quant_algo = None
elif isinstance(kv_cache_quant_algo_raw, str):
kv_cache_quant_algo = kv_cache_quant_algo_raw
else:
raise ValueError(f"kv_cache_quant_algo must be a string, got "
f"{type(kv_cache_quant_algo_raw)}")
# Handle group_size with proper type validation
group_size_raw = config.get("group_size")
if group_size_raw is None:
group_size = 16 # Default value
elif isinstance(group_size_raw, int):
group_size = group_size_raw
else:
try:
group_size = int(group_size_raw)
except (ValueError, TypeError):
raise ValueError(f"group_size must be an integer, got "
f"{type(group_size_raw)}") from None
exclude_modules = config.get("exclude_modules", [])
if not isinstance(exclude_modules, list):
raise ValueError(f"exclude_modules must be a list, got "
f"{type(exclude_modules)}")
if quant_method not in QUANT_ALGOS:
-raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
-" quantizations in vLLM. Please check the "
-"`hf_quant_config.json` file for your model's "
-"quant configuration.")
raise ValueError(
f"ModelOpt currently only supports: {QUANT_ALGOS} "
"quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.")
is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
-if ("group_size" and "kv_cache_quant_algo"
-and "exclude_modules") not in quant_config:
-raise ValueError("NVFP4 quantization requires group size and "
-"kv_cache_quant_algo specified in "
-"hf_quant_config.json")
-kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
-group_size = quant_config["group_size"]
-exclude_modules = quant_config["exclude_modules"]
# For FP4, these fields are required
if is_checkpoint_nvfp4_serialized and "quantization" in config:
# Check if required fields are present in the quantization config
quant_config = config["quantization"]
required_fields = [
"group_size", "kv_cache_quant_algo", "exclude_modules"
]
missing_fields = [
field for field in required_fields if field not in quant_config
]
if missing_fields:
raise ValueError(
f"NVFP4 quantization requires the following fields in "
f"hf_quant_config.json: {missing_fields}")
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)
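
Tying the FP4 path together, a hedged end-to-end sketch of parsing a nested NVFP4 config that satisfies the required-field check above; the values are illustrative and real checkpoints may carry additional fields.

# Illustrative nested NVFP4 config; for this layout the validation above
# requires group_size, kv_cache_quant_algo and exclude_modules to be present.
from vllm.model_executor.layers.quantization.modelopt import (
    ModelOptNvFp4Config)

nvfp4_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    },
}
cfg = ModelOptNvFp4Config.from_config(nvfp4_quant_config)
assert isinstance(cfg, ModelOptNvFp4Config)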