Support FP8 Quantization and Inference Run on Intel Gaudi (HPU) using INC (Intel Neural Compressor) (#12010)

Signed-off-by: Nir David <ndavid@habana.ai>
Signed-off-by: Uri Livne <ulivne@habana.ai>
Co-authored-by: Uri Livne <ulivne@habana.ai>
Nir David 2025-07-16 22:33:41 +03:00 committed by GitHub
parent ac2bf41e53
commit 01513a334a
11 changed files with 168 additions and 25 deletions

View File

@ -10,6 +10,7 @@ Contents:
- [BitBLAS](bitblas.md)
- [GGUF](gguf.md)
- [GPTQModel](gptqmodel.md)
- [INC](inc.md)
- [INT4 W4A16](int4.md)
- [INT8 W8A8](int8.md)
- [FP8 W8A8](fp8.md)

View File

@ -0,0 +1,56 @@
---
title: FP8 INC
---
[](){ #inc }
vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators.
Currently, quantization is validated only on Llama models.
Intel Gaudi supports quantization of various modules and functions, including, but not limited to, `Linear`, `KVCache`, `Matmul`, and `Softmax`. For more information, please refer to:
[Supported Modules / Supported Functions / Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules).
!!! note
Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.
!!! note
`QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options).
The measurement configuration file is used during the calibration procedure to collect measurements for a given model. The quantization configuration is used during inference.
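For example, the same `QUANT_CONFIG` variable serves both stages; the file names below are placeholders for the configs prepared for your model:
```bash
# Calibration run: point QUANT_CONFIG at the measurement config.
export QUANT_CONFIG=/path/to/quant/config/inc/maxabs_measure.json
# Inference run: point QUANT_CONFIG at the quantization (QUANTIZE mode) config.
export QUANT_CONFIG=/path/to/quant/config/inc/maxabs_quant.json
```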
## Run Online Inference Using FP8
Once you've completed the model calibration process and collected the measurements, you can run FP8 inference with vLLM using the following command:
```bash
export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxabs_measure_g3.json
vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor-parallel-size 8
```
!!! tip
If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments as it causes a significant performance drop.
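For example, a quick prototyping run can simply prepend the variable to the serve command shown above (a sketch only; keep warmup enabled for production deployments):
```bash
VLLM_SKIP_WARMUP=true vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor-parallel-size 8
```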
!!! tip
When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the following environment variables:
`VLLM_ENGINE_ITERATION_TIMEOUT_S` - adjusts the vLLM server timeout. The value is given in seconds, e.g., 600 equals 10 minutes.
`VLLM_RPC_TIMEOUT` - adjusts the RPC protocol timeout used by the OpenAI-compatible API. The value is given in milliseconds, e.g., 600000 equals 10 minutes.
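For example (the values are illustrative; adjust them to your model's compilation time):
```bash
export VLLM_ENGINE_ITERATION_TIMEOUT_S=600   # 10 minutes, in seconds
export VLLM_RPC_TIMEOUT=600000               # 10 minutes, in milliseconds
```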
## Run Offline Inference Using FP8
To run offline inference (after completing the model calibration process):
* Set the `QUANT_CONFIG` environment variable to point to a JSON configuration file in QUANTIZE mode.
* Pass `quantization=inc` and `kv_cache_dtype=fp8_inc` as parameters to the `LLM` object.
* Call the `shutdown` method of the `model_executor` at the end of the run.
```python
from vllm import LLM
llm = LLM("llama3.1/Meta-Llama-3.1-8B-Instruct", quantization="inc", kv_cache_dtype="fp8_inc")
...
# Call llm.generate on the required prompts and sampling params.
...
llm.llm_engine.model_executor.shutdown()
```
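For reference, a fuller offline-inference sketch might look like the following. The config path, prompt, and sampling parameters are illustrative placeholders, not outputs of the calibration procedure:
```python
import os

from vllm import LLM, SamplingParams

# Hypothetical path; point this at the QUANTIZE-mode JSON config produced
# for your model during calibration.
os.environ["QUANT_CONFIG"] = "/path/to/quant/config/inc/maxabs_quant_g3.json"

llm = LLM("meta-llama/Llama-3.1-8B-Instruct",
          quantization="inc",
          kv_cache_dtype="fp8_inc")

sampling_params = SamplingParams(temperature=0.7, max_tokens=128)
outputs = llm.generate(["What is FP8 quantization?"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)

# Release the HPU resources held by the INC-patched model.
llm.llm_engine.model_executor.shutdown()
```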
## Device for Loading the Model's Weights
The unquantized weights are first loaded onto the CPU, then quantized and transferred to the target device (HPU) for model execution.
This reduces the device memory footprint of model weights, as only quantized weights are stored in the device memory.

View File

@ -2,18 +2,19 @@
The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM:
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Neuron | Google TPU |
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------|------------------|--------------|
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ❌ |
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU |
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------|
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ |
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ |
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
- ✅︎ indicates that the quantization method is supported on the specified hardware.

View File

@ -28,7 +28,7 @@ To verify that the Intel Gaudi software was correctly installed, run:
hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed
pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed
pip list | grep neural # verify that neural_compressor is installed
pip list | grep neural # verify that neural_compressor_pt is installed
```
Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade)
@ -120,12 +120,13 @@ docker run \
- Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html)
for accelerating low-batch latency and throughput
- Attention with Linear Biases (ALiBi)
- INC quantization
### Unsupported features
- Beam search
- LoRA adapters
- Quantization
- AWQ quantization
- Prefill chunking (mixed-batch inferencing)
### Supported configurations

View File

@ -963,7 +963,7 @@ class ModelConfig:
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
]
if self.quantization is not None:
self.quantization = cast(me_quant.QuantizationMethods,
@ -1563,7 +1563,7 @@ class ModelConfig:
BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
@ -1593,7 +1593,7 @@ class CacheConfig:
cache_dtype: CacheDType = "auto"
"""Data type for kv cache storage. If "auto", will use model data type.
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
fp8 (=fp8_e4m3)."""
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
is_attention_free: bool = False
"""Whether the model is attention-free. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
@ -1691,7 +1691,7 @@ class CacheConfig:
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor")
"scaling factor.")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
@ -1781,6 +1781,9 @@ class LoadConfig:
default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format."""
device: Optional[str] = None
"""Device to which model weights will be loaded, default to
device_config.device"""
ignore_patterns: Optional[Union[list[str], str]] = None
"""The list of patterns to ignore when loading the model. Default to
"original/**/*" to avoid repeated loading of llama's checkpoints."""
@ -1907,7 +1910,7 @@ class ParallelConfig:
or equal to the number of GPUs available, "mp" will be used to
keep processing on a single host. Otherwise, this will default
to "ray" if Ray is installed and fail otherwise. Note that tpu
and hpu only support Ray for distributed inference."""
only supports Ray for distributed inference."""
worker_cls: str = "auto"
"""The full name of the worker class to use. If "auto", the worker class

View File

@ -139,6 +139,10 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
return type_hints
def is_online_quantization(quantization: Any) -> bool:
return quantization in ["inc"]
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
@ -960,6 +964,8 @@ class EngineArgs:
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
device="cpu"
if is_online_quantization(self.quantization) else None,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
use_tqdm_on_load=self.use_tqdm_on_load,
@ -1359,7 +1365,9 @@ class EngineArgs:
supported = False
if current_platform.is_rocm() or (
current_platform.is_cuda()
and current_platform.is_device_capability(100)):
and current_platform.is_device_capability(100)) or (
current_platform.device_name
== "hpu"): # handle hpu also for OOT platform
supported = True
elif fp8_attention and will_use_fa:
from vllm.attention.utils.fa_utils import (

View File

@ -36,6 +36,7 @@ QuantizationMethods = Literal[
"torchao",
"auto-round",
"rtn",
"inc",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@ -104,6 +105,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .gptq_marlin import GPTQMarlinConfig
from .gptq_marlin_24 import GPTQMarlin24Config
from .hqq_marlin import HQQMarlinConfig
from .inc import INCConfig
from .ipex_quant import IPEXConfig
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
@ -144,7 +146,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig,
"auto-round": AutoRoundConfig,
"rtn": RTNConfig
"rtn": RTNConfig,
"inc": INCConfig,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
@ -157,4 +160,4 @@ __all__ = [
"QuantizationMethods",
"get_quantization_config",
"QUANTIZATION_METHODS",
]
]

View File

@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Intel Gaudi supports quantization of various modules and functions,
# including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`.
# During model loading, INC patches layers with
# quantization/dequantization operators, converts the original weights
# to the target data type, and loads them to the target device.
# Static scaling factors should be provided through `QUANT_CONFIG`:
# an environment variable that points to the measurement or
# quantization JSON config file.
# The measurement configuration file is used during the calibration
# procedure to collect measurements for a given model.
# The quantization configuration is used during inference.
# For more information, please refer to:
# https://docs.habana.ai/en/v1.21.1/PyTorch/vLLM_Inference/vLLM_FP8_Inference.html
from typing import Any, Optional
import torch
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
class INCConfig(QuantizationConfig):
"""Config class for FP8 using Intel Neural Compressor."""
@classmethod
def get_name(cls) -> QuantizationMethods:
return "inc"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "INCConfig":
raise AssertionError
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return UnquantizedFusedMoEMethod(layer.moe_config)
return None
@classmethod
def get_min_capability(cls) -> int:
raise AssertionError
@staticmethod
def get_config_filenames() -> list[str]:
return []

View File

@ -6,9 +6,12 @@ import torch
import torch.nn as nn
from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
logger = init_logger(__name__)
class BaseModelLoader(ABC):
"""Base class for model loaders."""
@ -32,11 +35,16 @@ class BaseModelLoader(ABC):
model_config: ModelConfig) -> nn.Module:
"""Load a model with the given configurations."""
device_config = vllm_config.device_config
target_device = torch.device(device_config.device)
load_config = vllm_config.load_config
load_device = device_config.device if load_config.device is None else \
load_config.device
target_device = torch.device(load_device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)
logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)

View File

@ -152,8 +152,8 @@ def get_quant_config(model_config: ModelConfig,
quant_cls = get_quantization_config(model_config.quantization)
# GGUF doesn't have config file
if model_config.quantization == "gguf":
return quant_cls.from_config({})
if model_config.quantization in ("gguf", "inc"):
return quant_cls()
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",

View File

@ -179,6 +179,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
"int8": torch.int8,
"fp8_inc": torch.float8_e4m3fn,
}
TORCH_DTYPE_TO_NUMPY_DTYPE = {