diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py
new file mode 100644
index 000000000000..fcbfa681d75c
--- /dev/null
+++ b/tests/quantization/test_modelopt.py
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Test ModelOpt quantization method setup and weight loading.
+
+Run `pytest tests/quantization/test_modelopt.py`.
+"""
+
+import os
+
+import pytest
+import torch
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm.platforms import current_platform
+
+
+@pytest.fixture(scope="function", autouse=True)
+def use_v0_only(monkeypatch):
+    """
+    This module relies on V0 internals, so set VLLM_USE_V1=0.
+    """
+    if not current_platform.is_cpu():
+        monkeypatch.setenv('VLLM_USE_V1', '0')
+
+
+@pytest.mark.skipif(not is_quant_method_supported("modelopt"),
+                    reason="ModelOpt FP8 is not supported on this GPU type.")
+def test_modelopt_fp8_checkpoint_setup(vllm_runner):
+    """Test ModelOpt FP8 checkpoint loading and structure validation."""
+    # TODO: provide a small publicly available test checkpoint
+    model_path = ("/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/"
+                  "TinyLlama-1.1B-Chat-v1.0-fp8-0710")
+
+    # Skip test if checkpoint doesn't exist
+    if not os.path.exists(model_path):
+        pytest.skip(f"Test checkpoint not found at {model_path}. "
+                    "This test requires a local ModelOpt FP8 checkpoint.")
+
+    with vllm_runner(model_path, quantization="modelopt",
+                     enforce_eager=True) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            o_proj = layer.self_attn.o_proj
+            gate_up_proj = layer.mlp.gate_up_proj
+            down_proj = layer.mlp.down_proj
+
+            # Check that ModelOpt quantization method is properly applied
+            from vllm.model_executor.layers.quantization.modelopt import (
+                ModelOptFp8LinearMethod)
+            assert isinstance(qkv_proj.quant_method, ModelOptFp8LinearMethod)
+            assert isinstance(o_proj.quant_method, ModelOptFp8LinearMethod)
+            assert isinstance(gate_up_proj.quant_method,
+                              ModelOptFp8LinearMethod)
+            assert isinstance(down_proj.quant_method, ModelOptFp8LinearMethod)
+
+            # Check weight dtype is FP8
+            assert qkv_proj.weight.dtype == torch.float8_e4m3fn
+            assert o_proj.weight.dtype == torch.float8_e4m3fn
+            assert gate_up_proj.weight.dtype == torch.float8_e4m3fn
+            assert down_proj.weight.dtype == torch.float8_e4m3fn
+
+            # Check scales are present and have correct dtype
+            assert hasattr(qkv_proj, 'weight_scale')
+            assert hasattr(qkv_proj, 'input_scale')
+            assert qkv_proj.weight_scale.dtype == torch.float32
+            assert qkv_proj.input_scale.dtype == torch.float32
+
+            assert hasattr(o_proj, 'weight_scale')
+            assert hasattr(o_proj, 'input_scale')
+            assert o_proj.weight_scale.dtype == torch.float32
+            assert o_proj.input_scale.dtype == torch.float32
+
+            assert hasattr(gate_up_proj, 'weight_scale')
+            assert hasattr(gate_up_proj, 'input_scale')
+            assert gate_up_proj.weight_scale.dtype == torch.float32
+            assert gate_up_proj.input_scale.dtype == torch.float32
+
+            assert hasattr(down_proj, 'weight_scale')
+            assert hasattr(down_proj, 'input_scale')
+            assert down_proj.weight_scale.dtype == torch.float32
+            assert down_proj.input_scale.dtype == torch.float32
+
+        llm.apply_model(check_model)
+
+        # Run a simple generation test to ensure the model works
+        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
+        assert output
+        print(f"ModelOpt FP8 output: {output}")
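The checkpoint referenced above is a private path, so for orientation here is a minimal sketch of the kind of `hf_quant_config.json` payload such an FP8 checkpoint is expected to carry, run through the `ModelOptFp8Config.from_config` path changed further down in this diff. The `kv_cache_quant_algo` and `exclude_modules` values are illustrative assumptions, not read from the TinyLlama checkpoint.

```python
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

# Sketch of a ModelOpt FP8 hf_quant_config.json, expressed as the dict vLLM
# parses. The kv_cache_quant_algo / exclude_modules values are assumptions.
hf_quant_config = {
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    },
}

cfg = ModelOptFp8Config.from_config(hf_quant_config)
assert cfg.is_checkpoint_fp8_serialized
```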
diff --git a/vllm/config.py b/vllm/config.py
index 4cafbc926052..3e6aa2a93e6a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -346,11 +346,11 @@ class ModelConfig:
     """Maximum number of data items per modality per prompt. Only applicable
     for multimodal models."""
     interleave_mm_strings: bool = False
-    """Enable fully interleaved support for multimodal prompts, while using 
+    """Enable fully interleaved support for multimodal prompts, while using
     --chat-template-content-format=string. Defaults to False."""
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
-    """Additional args passed to process media inputs, keyed by modalities. 
-    For example, to set num_frames for video, set 
+    """Additional args passed to process media inputs, keyed by modalities.
+    For example, to set num_frames for video, set
     `--media-io-kwargs '{"video": {"num_frames": 40} }'` """
     use_async_output_proc: bool = True
     """Whether to use async output processor."""
@@ -1000,9 +1000,13 @@ class ModelConfig:
             quant_cfg = self._parse_quant_hf_config()
 
             if quant_cfg is not None:
+                # Use the community standard 'quant_method'
                 quant_method = quant_cfg.get("quant_method", "").lower()
+
+                # Normalize library names
                 quant_method = quant_method.replace("compressed_tensors",
                                                     "compressed-tensors")
+
                 quant_cfg["quant_method"] = quant_method
 
                 # Quantization methods which are overrides (i.e. they have a
@@ -1017,6 +1021,8 @@
                     "awq_marlin",
                     "ipex",
                     "moe_wna16",
+                    "modelopt",
+                    "modelopt_fp4",
                 ]
                 quantization_methods = [
                     q for q in supported_quantization if q not in overrides
@@ -3185,8 +3191,8 @@ class MultiModalConfig:
     """
 
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
-    """Additional args passed to process media inputs, keyed by modalities. 
-    For example, to set num_frames for video, set 
+    """Additional args passed to process media inputs, keyed by modalities.
+    For example, to set num_frames for video, set
     `--media-io-kwargs '{"video": {"num_frames": 40} }'` """
 
     mm_processor_kwargs: Optional[dict[str, object]] = None
@@ -4086,7 +4092,7 @@ class CompilationConfig:
     - True: inductor compilation is used (custom_ops disabled by default).
       One graph for symbolic shape and one graph per size in compile_sizes
      are compiled using configurations in inductor_compile_config.
-    
+
    This setting is ignored if level<PIECEWISE."""

    `-O<n>` can be used to directly specify the compilation level `n`:
    `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
-    Currently, -O and -O= are supported as well but this will likely be 
+    Currently, -O and -O= are supported as well but this will likely be
    removed in favor of clearer -O syntax in the future.

    NOTE: level 0 is the default level without any optimization. level 1 and 2
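With "modelopt" and "modelopt_fp4" added to the override list, the quantization-verification path in ModelConfig hands the parsed checkpoint config (with the normalized `quant_method` key written back) to each override hook. A small sketch of how the hooks added in the modelopt.py diff below resolve the two accepted layouts once this patch is applied; the dicts are illustrative inputs, not real checkpoint files.

```python
from vllm.model_executor.layers.quantization.modelopt import (
    ModelOptFp8Config, ModelOptNvFp4Config)

# Nested ModelOpt layout (hf_quant_config.json style).
nested_fp8 = {"quant_method": "modelopt",
              "quantization": {"quant_algo": "FP8"}}
# Flat, compressed-tensors style layout.
flat_fp4 = {"quant_method": "modelopt", "quant_algo": "NVFP4"}

print(ModelOptFp8Config.override_quantization_method(nested_fp8, None))   # modelopt
print(ModelOptNvFp4Config.override_quantization_method(flat_fp4, None))   # modelopt_fp4
print(ModelOptFp8Config.override_quantization_method(flat_fp4, None))     # None (not FP8)
```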
level 1 and 2 diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 20def70d1976..460334d77f0a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -75,20 +75,64 @@ class ModelOptFp8Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + """Detect if this ModelOpt config should be used based on + quantization config.""" + + if hf_quant_cfg is None: + return None + + # Use the community standard 'quant_method' + quant_method = hf_quant_cfg.get("quant_method", "").lower() + + # Only proceed if the method is explicitly "modelopt" + if quant_method != "modelopt": + return None + + # Look for ModelOpt-specific config structure + if "quantization" in hf_quant_cfg: + quant_config = hf_quant_cfg["quantization"] + if isinstance(quant_config, dict): + quant_algo = quant_config.get("quant_algo", "") + if "FP8" in quant_algo: + return "modelopt" + else: + # Check for compressed-tensors style config with specific quant_algo + quant_algo = hf_quant_cfg.get("quant_algo", "") + if isinstance(quant_algo, str) and "FP8" in quant_algo: + return "modelopt" + + return None + @classmethod def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config": - quant_config = cls.get_from_keys(config, ["quantization"]) - quant_method = quant_config["quant_algo"] - kv_cache_quant_method = cls.get_from_keys( - config, ["quantization"]).get("kv_cache_quant_algo") - exclude_modules = cls.get_from_keys( - config, ["quantization"]).get("exclude_modules") + # Handle both ModelOpt format and compressed-tensors style format + if "quantization" in config: + # ModelOpt format: {"quantization": {"quant_algo": "..."}} + quant_config = cls.get_from_keys(config, ["quantization"]) + if not isinstance(quant_config, dict): + raise ValueError( + "Expected 'quantization' to be a dictionary in config") + quant_method = quant_config.get("quant_algo", "") + if not quant_method: + raise ValueError("Missing 'quant_algo' in quantization config") + kv_cache_quant_method = quant_config.get("kv_cache_quant_algo") + exclude_modules = quant_config.get("exclude_modules") + else: + # Compressed-tensors style format: + # {"quant_algo": "...", "quant_method": "modelopt"} + quant_method = config.get("quant_algo", "") + kv_cache_quant_method = config.get("kv_cache_quant_algo") + exclude_modules = config.get("exclude_modules") if quant_method not in QUANT_ALGOS: - raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" - " quantizations in vLLM. Please check the " - "`hf_quant_config.json` file for your model's " - "quant configuration.") + raise ValueError( + f"ModelOpt currently only supports: {QUANT_ALGOS} " + "quantizations in vLLM. 
+                "quantizations in vLLM. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration.")
         is_checkpoint_fp8_serialized = ("FP8" in quant_method)
 
         return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
@@ -434,7 +478,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
     def __init__(
         self,
         is_checkpoint_nvfp4_serialized: bool,
-        kv_cache_quant_algo: str,
+        kv_cache_quant_algo: Optional[str],
         exclude_modules: list[str],
         group_size: int = 16,
     ) -> None:
@@ -465,24 +509,138 @@ class ModelOptNvFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return ["hf_quant_config.json"]
 
+    @classmethod
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
+        """Detect if this ModelOpt FP4 config should be used based on
+        quantization config."""
+        if hf_quant_cfg is None:
+            return None
+
+        # Use the community standard 'quant_method'
+        quant_method = hf_quant_cfg.get("quant_method", "").lower()
+
+        # Only proceed if the method is explicitly "modelopt"
+        if quant_method != "modelopt":
+            return None
+
+        # Look for ModelOpt-specific config structure
+        if "quantization" in hf_quant_cfg:
+            quant_config = hf_quant_cfg["quantization"]
+            if isinstance(quant_config, dict):
+                quant_algo = quant_config.get("quant_algo", "")
+                if "NVFP4" in quant_algo:
+                    return "modelopt_fp4"
+        else:
+            # Check for compressed-tensors style config with specific
+            # quant_algo field
+            quant_algo = hf_quant_cfg.get("quant_algo", "")
+            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
+                return "modelopt_fp4"
+
+        return None
+
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
-        quant_config = cls.get_from_keys(config, ["quantization"])
-        quant_method = quant_config["quant_algo"]
+        # Handle both traditional ModelOpt format and compressed-tensors
+        # style format
+        if "quantization" in config:
+            # Traditional ModelOpt format:
+            # {"quantization": {"quant_algo": "..."}}
+            quant_config = cls.get_from_keys(config, ["quantization"])
+            if not isinstance(quant_config, dict):
+                raise ValueError(
+                    "Expected 'quantization' to be a dictionary in config")
+
+            quant_method = quant_config.get("quant_algo", "")
+            if not quant_method:
+                raise ValueError("Missing 'quant_algo' in quantization config")
+
+            # Handle kv_cache_quant_algo with proper type validation
+            kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
+            if kv_cache_quant_algo_raw is None:
+                # No KV cache quantization by default
+                kv_cache_quant_algo = None
+            elif isinstance(kv_cache_quant_algo_raw, str):
+                kv_cache_quant_algo = kv_cache_quant_algo_raw
+            else:
+                raise ValueError(f"kv_cache_quant_algo must be a string, got "
+                                 f"{type(kv_cache_quant_algo_raw)}")
+
+            # Handle group_size with proper type validation
+            group_size_raw = quant_config.get("group_size")
+            if group_size_raw is None:
+                group_size = 16  # Default value
+            elif isinstance(group_size_raw, int):
+                group_size = group_size_raw
+            else:
+                try:
+                    group_size = int(group_size_raw)
+                except (ValueError, TypeError):
+                    raise ValueError(f"group_size must be an integer, got "
+                                     f"{type(group_size_raw)}") from None
+
+            exclude_modules = quant_config.get("exclude_modules", [])
+            if not isinstance(exclude_modules, list):
+                raise ValueError(f"exclude_modules must be a list, got "
+                                 f"{type(exclude_modules)}")
+        else:
+            # Compressed-tensors style format:
+            # {"quant_algo": "...", "quant_method": "modelopt"}
+            quant_method = config.get("quant_algo", "")
+
+            # Handle kv_cache_quant_algo with proper type validation
+            kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
+            if kv_cache_quant_algo_raw is None:
+                # No KV cache quantization by default
+                kv_cache_quant_algo = None
+            elif isinstance(kv_cache_quant_algo_raw, str):
+                kv_cache_quant_algo = kv_cache_quant_algo_raw
+            else:
+                raise ValueError(f"kv_cache_quant_algo must be a string, got "
+                                 f"{type(kv_cache_quant_algo_raw)}")
+
+            # Handle group_size with proper type validation
+            group_size_raw = config.get("group_size")
+            if group_size_raw is None:
+                group_size = 16  # Default value
+            elif isinstance(group_size_raw, int):
+                group_size = group_size_raw
+            else:
+                try:
+                    group_size = int(group_size_raw)
+                except (ValueError, TypeError):
+                    raise ValueError(f"group_size must be an integer, got "
+                                     f"{type(group_size_raw)}") from None
+
+            exclude_modules = config.get("exclude_modules", [])
+            if not isinstance(exclude_modules, list):
+                raise ValueError(f"exclude_modules must be a list, got "
+                                 f"{type(exclude_modules)}")
+
         if quant_method not in QUANT_ALGOS:
-            raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
-                             " quantizations in vLLM. Please check the "
-                             "`hf_quant_config.json` file for your model's "
-                             "quant configuration.")
+            raise ValueError(
+                f"ModelOpt currently only supports: {QUANT_ALGOS} "
+                "quantizations in vLLM. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration.")
         is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
-        if ("group_size" and "kv_cache_quant_algo"
-                and "exclude_modules") not in quant_config:
-            raise ValueError("NVFP4 quantization requires group size and "
-                             "kv_cache_quant_algo specified in "
-                             "hf_quant_config.json")
-        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
-        group_size = quant_config["group_size"]
-        exclude_modules = quant_config["exclude_modules"]
+
+        # For FP4, these fields are required
+        if is_checkpoint_nvfp4_serialized and "quantization" in config:
+            # Check if required fields are present in the quantization config
+            quant_config = config["quantization"]
+            required_fields = [
+                "group_size", "kv_cache_quant_algo", "exclude_modules"
+            ]
+            missing_fields = [
+                field for field in required_fields if field not in quant_config
+            ]
+            if missing_fields:
+                raise ValueError(
+                    f"NVFP4 quantization requires the following fields in "
+                    f"hf_quant_config.json: {missing_fields}")
+
         return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                    exclude_modules, group_size)
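To make the two accepted NVFP4 layouts concrete, here is a quick sketch against the patched `from_config`: the nested form must carry `group_size`, `kv_cache_quant_algo`, and `exclude_modules`, while the flat form falls back to the defaults introduced above (group_size=16, no KV-cache quantization, empty exclude list). The field values are illustrative, and the sketch assumes the constructor stores its arguments as same-named attributes.

```python
from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config

# Traditional nested ModelOpt layout: the three extra fields are required.
nested = {
    "quantization": {
        "quant_algo": "NVFP4",
        "group_size": 16,
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    },
}

# Flat, compressed-tensors style layout: defaults are filled in.
flat = {"quant_method": "modelopt", "quant_algo": "NVFP4"}

for cfg in (nested, flat):
    parsed = ModelOptNvFp4Config.from_config(cfg)
    print(parsed.group_size, parsed.kv_cache_quant_algo,
          parsed.exclude_modules)
# Expected: 16 FP8 ['lm_head']  /  16 None []
```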