diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 7038c9024c6b..c3bb65b70a0b 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -17,4 +17,4 @@ setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 -conch-triton-kernels==1.2.1 +conch-triton-kernels==1.2.1 \ No newline at end of file diff --git a/setup.py b/setup.py index fa406b868c07..ca6e0a8592cc 100644 --- a/setup.py +++ b/setup.py @@ -695,6 +695,8 @@ setup( "video": [], # Kept for backwards compatibility # FlashInfer should be updated together with the Dockerfile "flashinfer": ["flashinfer-python==0.2.12"], + # Optional deps for AMD FP4 quantization support + "petit-kernel": ["petit-kernel"], }, cmdclass=cmdclass, package_data=package_data, diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index fbc4dd3989f5..6ce40626b3a8 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1119,9 +1119,20 @@ class ModelConfig: def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS optimized_quantization_methods = [ - "fp8", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", - "fbgemm_fp8", "compressed-tensors", "experts_int8", "quark", - "modelopt_fp4", "bitblas", "gptq_bitblas", "inc" + "fp8", + "modelopt", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "fbgemm_fp8", + "compressed-tensors", + "experts_int8", + "quark", + "modelopt_fp4", + "bitblas", + "gptq_bitblas", + "inc", + "petit_nvfp4", ] if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, @@ -1153,6 +1164,7 @@ class ModelConfig: "moe_wna16", "modelopt", "modelopt_fp4", + "petit_nvfp4", ] quantization_methods = [ q for q in supported_quantization if q not in overrides diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 9b1ab7af0ac8..5725c841e529 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -52,6 +52,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "HQQMarlinMethod", "QuarkLinearMethod", "ModelOptNvFp4LinearMethod", + "PetitNvFp4LinearMethod", ] diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index ea51468422dc..d73fcf368f26 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -35,6 +35,7 @@ QuantizationMethods = Literal[ "rtn", "inc", "mxfp4", + "petit_nvfp4", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config from .neuron_quant import NeuronQuantConfig + from .petit import PetitNvFp4Config from .ptpc_fp8 import PTPCFp8Config from .rtn import RTNConfig from .torchao import TorchAOConfig @@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "rtn": RTNConfig, "inc": INCConfig, "mxfp4": Mxfp4Config, + "petit_nvfp4": PetitNvFp4Config, } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/petit.py b/vllm/model_executor/layers/quantization/petit.py new file mode 100644 index 000000000000..5b9fee69bb02 --- /dev/null +++ b/vllm/model_executor/layers/quantization/petit.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py + +from typing import Any, Optional + +import regex as re +import torch +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.petit_utils import ( + apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit, + verify_petit_nvfp4_supported) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +# Initialize logger for the module +logger = init_logger(__name__) + + +# Configuration class to support the NVFP4 quantized model +# generated by the ModelOpt quantization tool +class PetitNvFp4Config(QuantizationConfig): + """Config class for Petit FP4.""" + + def __init__( + self, + is_checkpoint_nvfp4_serialized: bool = False, + kv_cache_quant_algo: Optional[str] = None, + group_size: Optional[int] = None, + exclude_modules: Optional[list[str]] = None, + ) -> None: + self._check_hardware_support() + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning("Detected nvfp4 checkpoint. Please note that the " + "format is experimental and subject to change.") + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + self.exclude_modules = exclude_modules + + def _check_hardware_support(self) -> None: + """ + Verifies that the current hardware is supported by the Petit backend. + This backend is specifically designed for AMD GPUs and is not + supported on the CUDA platform. + """ + # This check ensures the code is NOT running on an NVIDIA GPU. + if current_platform.is_cuda(): + raise ValueError( + "The 'petit' quantization backend is designed for AMD GPUs " + "and is not supported on the CUDA platform. For NVIDIA GPUs, " + "please use a different quantization method such as FP8, AWQ, " + "or GPTQ.") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "petit_nvfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # Petit supports the gfx90a and gfx942 GPUs + return 90 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config": + qc = cls.get_from_keys(config, ["quantization"]) + + quant_method_raw = qc.get("quant_algo") + if not isinstance(quant_method_raw, str) or not quant_method_raw: + raise ValueError( + "Missing or invalid 'quant_algo' in quantization config.") + quant_method = quant_method_raw.upper() + + group_size_raw = qc.get("group_size") + if not isinstance(group_size_raw, int): + raise ValueError( + "Missing or invalid 'group_size' (int) in hf_quant_config.json." + ) + group_size = group_size_raw + + verify_petit_nvfp4_supported(quant_method, group_size) + + kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto" + if not isinstance(kv_cache_quant_algo_raw, str): + raise ValueError( + "'kv_cache_quant_algo' must be a string if provided.") + kv_cache_quant_algo = kv_cache_quant_algo_raw + + exclude_raw = qc.get("exclude_modules", []) + if exclude_raw is None: + exclude_modules: list[str] = [] + elif isinstance(exclude_raw, list) and all( + isinstance(x, str) for x in exclude_raw): + exclude_modules = exclude_raw + else: + raise ValueError( + "'exclude_modules' must be a list[str] (or omitted).") + + is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method + + return cls( + is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized, + kv_cache_quant_algo=kv_cache_quant_algo, + group_size=group_size, + exclude_modules=exclude_modules, + ) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + if not current_platform.is_rocm(): + return None + + qc = hf_quant_cfg.get("quantization", hf_quant_cfg) + algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() + if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"): + return cls.get_name() # "petit_nvfp4" + return None + + @classmethod + def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool: + qc = quant_config.get("quantization", quant_config) + algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() + return algo == "NVFP4" + + def is_layer_excluded(self, prefix: str, + exclude_modules: list[str]) -> bool: + for pattern in exclude_modules: + regex_str = pattern.replace(".", r"\.").replace("*", r".*") + if re.fullmatch(regex_str, prefix): + return True + return False + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + exclude = self.require_exclude_modules() + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, exclude) or self.is_layer_excluded( + prefix, exclude): + return UnquantizedLinearMethod() + return PetitNvFp4LinearMethod(self) + elif isinstance(layer, Attention): + return PetitFp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> list[str]: + return [] + + def require_group_size(self) -> int: + if self.group_size is None: + logger.warning("group_size not set; defaulting to 16 for NVFP4.") + return 16 + return self.group_size + + def require_kv_cache_quant_algo(self) -> str: + return self.kv_cache_quant_algo or "auto" + + def require_exclude_modules(self) -> list[str]: + return list(self.exclude_modules or []) + + +class PetitFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: PetitNvFp4Config): + super().__init__(quant_config) + + +class PetitNvFp4LinearMethod(LinearMethodBase): + """Linear method for NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + |Tensor Name | datatype | shape | + |----------------------------------------------------| + |input_scale | torch.float32 | scalar | + |weight | NVFP4(SE2M1) | [1, X, y/2] | + |weight_scale | FP8-E4M3 | [X, Y] | + |weight_scale_2 | torch.float32 | scalar | + + The weights are quantized per block of 16 elements. + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: PetitNvFp4Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + if input_size_per_partition % 16 != 0: + raise ValueError("Unsupported model when in features size is " + "not multiple of 16") + + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype) + + weight = ModelWeightParameter( + data=torch.empty( + # 2 fp4 data is packed in one uint8 in the input dimension + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + layer.register_parameter("input_scale", input_scale) + + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale_2", weight_scale_2) + + group_size = self.quant_config.require_group_size() + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + input_scale_2 = layer.input_scale.max().to(torch.float32) + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.input_scale = Parameter(input_scale_2, requires_grad=False) + layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, + requires_grad=False) + + prepare_nvfp4_layer_for_petit(layer) + del layer.input_scale + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_petit_nvfp4_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/utils/petit_utils.py b/vllm/model_executor/layers/quantization/utils/petit_utils.py new file mode 100644 index 000000000000..00d3def1db81 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/petit_utils.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Optional + +import torch + +# TYPE_CHECKING is used for static type analysis to prevent circular imports. +if TYPE_CHECKING: + from types import ModuleType + +# 1. Create a global variable as a placeholder for the module +_petit_kernel: Optional["ModuleType"] = None + +_PETIT_INSTALL_MSG = ("Petit is not installed. Please install it with " + "`pip install petit-kernel`.") + + +def _import_petit_kernel() -> "ModuleType": + """ + A helper function to handle the lazy import. + The first time this function is called, it will import the petit_kernel + library and store it in the global _petit_kernel variable. + Subsequent calls will return the already-loaded module directly. + """ + global _petit_kernel + if _petit_kernel is not None: + return _petit_kernel + + try: + import petit_kernel + _petit_kernel = petit_kernel + return _petit_kernel + except ImportError: + # The 'from None' syntax prevents chaining the original ImportError, + # making the traceback cleaner. + raise ImportError(_PETIT_INSTALL_MSG) from None + + +# The _require_petit function can now be a simple alias for consistency. +_require_petit = _import_petit_kernel + + +def _check_petit_nvfp4_supported( + quant_method: str, + group_size: Optional[int]) -> tuple[bool, Optional[str]]: + if quant_method != "NVFP4": + return ( + False, + ("Petit currently only supports: NVFP4 quantizations in sglang. " + "Please check the `hf_quant_config.json` file for your model's " + "quant configuration."), + ) + if group_size is not None and group_size != 16: + return ( + False, + "Petit currently only supports: group_size=16 quantizations.", + ) + return (True, None) + + +def verify_petit_nvfp4_supported(quant_method: str, + group_size: Optional[int]) -> None: + supported, error_msg = _check_petit_nvfp4_supported( + quant_method, group_size) + if not supported: + assert error_msg is not None + raise ValueError(error_msg) + + +def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None: + # 2. Call _import_petit_kernel() to trigger (or get) the import. + petit_kernel = _import_petit_kernel() + + # Repack weights to petit format + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + qweight = layer.weight.view(torch.int32).contiguous() + + # 3. Call functions through the imported module variable. + petit_qweight = petit_kernel.repack_nvfp4(qweight, + size_n=part_size_n, + size_k=part_size_k) + layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False) + + # Permute scales + weight_scale = petit_kernel.process_nvfp4_scales(scales=layer.weight_scale, + size_k=part_size_k, + size_n=part_size_n) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + +def apply_petit_nvfp4_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # Trigger (or get) the import here as well. + petit_kernel = _import_petit_kernel() + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + # TODO: Use auto-tuning to find the performant solution_id + # Call the function via the module variable. + output = petit_kernel.mul_nvfp4_a16( + a=reshaped_x, + b=weight, + s=weight_scale, + global_scale=weight_scale_2, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + solution_id=-1, + ) + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 317bc401a799..323ec591c50a 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -171,7 +171,7 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8", "mxfp4" + "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4" ] @classmethod