Support FP8 Quantization and Inference Run on Intel Gaudi (HPU) using INC (Intel Neural Compressor) (#12010)
Signed-off-by: Nir David <ndavid@habana.ai>
Signed-off-by: Uri Livne <ulivne@habana.ai>
Co-authored-by: Uri Livne <ulivne@habana.ai>
This commit is contained in:
parent ac2bf41e53
commit 01513a334a
@@ -10,6 +10,7 @@ Contents:
- [BitBLAS](bitblas.md)
- [GGUF](gguf.md)
- [GPTQModel](gptqmodel.md)
- [INC](inc.md)
- [INT4 W4A16](int4.md)
- [INT8 W8A8](int8.md)
- [FP8 W8A8](fp8.md)
docs/features/quantization/inc.md (new file, 56 lines)
@@ -0,0 +1,56 @@
---
title: FP8 INC
---
[](){ #inc }

vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators.
Currently, quantization is validated only on Llama models.

Intel Gaudi supports quantization of various modules and functions, including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`. For more information, please refer to:
[Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules).

!!! note
    Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.

!!! note
    `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options).
    The measurement configuration file is used during the calibration procedure to collect measurements for a given model. The quantization configuration is used during inference.
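
To make the two uses concrete, here is a brief editor-added sketch of how `QUANT_CONFIG` is typically switched between a calibration run and a quantized-inference run. The file names below are placeholders; the calibration procedure itself is described in the vllm-hpu-extension README linked above.

```python
import os

# Calibration: point QUANT_CONFIG at a measurement-mode JSON config, then run
# the calibration workload (see the vllm-hpu-extension README).
os.environ["QUANT_CONFIG"] = "/path/to/maxabs_measure.json"  # placeholder path

# Quantized inference: point QUANT_CONFIG at a QUANTIZE-mode JSON config before
# starting vLLM with --quantization inc (see the commands below).
os.environ["QUANT_CONFIG"] = "/path/to/maxabs_quant.json"  # placeholder path
```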

## Run Online Inference Using FP8

Once you've completed the model calibration process and collected the measurements, you can run FP8 inference with vLLM using the following command:

```bash
export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxabs_measure_g3.json
vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor-parallel-size 8
```
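
Once the server is up, it can be queried through the OpenAI-compatible API as usual. Below is a brief editor-added sketch using the `openai` Python client; the port and prompt are placeholders.

```python
from openai import OpenAI

# Query the FP8-quantized model served above via the OpenAI-compatible API.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="meta-llama/Llama-3.1-405B-Instruct",
    prompt="Intel Gaudi accelerators support FP8 inference because",
    max_tokens=64,
)
print(completion.choices[0].text)
```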

!!! tip
    If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments as it causes a significant performance drop.

!!! tip
    When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the following environment variables:

    `VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. The value is in seconds, e.g., 600 equals 10 minutes.

    `VLLM_RPC_TIMEOUT` - to adjust the RPC protocol timeout used by the OpenAI-compatible API. The value is in milliseconds, e.g., 600000 equals 10 minutes.
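
For illustration only, the snippet below shows one way to apply these settings programmatically before launching the server from Python; the values and config path are placeholders, and plain `export` statements in the shell work just as well.

```python
import os
import subprocess

env = dict(os.environ)
env["VLLM_ENGINE_ITERATION_TIMEOUT_S"] = "600"      # seconds (10 minutes)
env["VLLM_RPC_TIMEOUT"] = "600000"                  # milliseconds (10 minutes)
env["QUANT_CONFIG"] = "/path/to/quant/config.json"  # placeholder path

subprocess.run(
    ["vllm", "serve", "meta-llama/Llama-3.1-405B-Instruct",
     "--quantization", "inc", "--kv-cache-dtype", "fp8_inc",
     "--tensor-parallel-size", "8"],
    env=env,
    check=True,
)
```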

## Run Offline Inference Using FP8

To run offline inference (after completing the model calibration process):

* Set the `QUANT_CONFIG` environment variable to point to a JSON configuration file with QUANTIZE mode.
* Pass `quantization=inc` and `kv_cache_dtype=fp8_inc` as parameters to the `LLM` object.
* Call the `shutdown` method of the `model_executor` at the end of the run.

```python
from vllm import LLM, SamplingParams

# QUANT_CONFIG must point to a QUANTIZE-mode JSON config file (see above).
llm = LLM("llama3.1/Meta-Llama-3.1-8B-Instruct",
          quantization="inc",
          kv_cache_dtype="fp8_inc")

# Call llm.generate on the required prompts and sampling params.
prompts = ["What does FP8 quantization do?"]
outputs = llm.generate(prompts, SamplingParams(temperature=0.8, max_tokens=64))
print(outputs[0].outputs[0].text)

# Shut down the model executor at the end of the run.
llm.llm_engine.model_executor.shutdown()
```

## Device for Loading the Model's Weights

The unquantized weights are first loaded onto the CPU, then quantized and transferred to the target device (HPU) for model execution.
This reduces the device memory footprint of model weights, as only quantized weights are stored in the device memory.
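
As a simplified illustration of this behavior (an editor's sketch; the actual logic lives in vLLM's model loader and appears in the diff further below), the load device falls back to the execution device unless the load config overrides it:

```python
from typing import Optional

def pick_load_device(load_config_device: Optional[str],
                     device_config_device: str) -> str:
    # For INC online quantization, vLLM sets load_config.device = "cpu", so the
    # unquantized weights are materialized on the CPU, quantized by INC, and
    # only the quantized weights end up in device (HPU) memory.
    return (device_config_device
            if load_config_device is None else load_config_device)

assert pick_load_device("cpu", "hpu") == "cpu"  # INC path: load weights on CPU
assert pick_load_device(None, "hpu") == "hpu"   # default: load on the device
```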
@@ -2,18 +2,19 @@

The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM:

| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Neuron | Google TPU |
|----------------|-------|--------|--------|-----|--------|---------|-----------|---------|------------|------------|
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ❌ |
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |

| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU |
|----------------|-------|--------|--------|-----|--------|---------|-----------|-------------|---------|------------|------------|
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ |
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ |

- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
- ✅︎ indicates that the quantization method is supported on the specified hardware.
@@ -28,7 +28,7 @@ To verify that the Intel Gaudi software was correctly installed, run:
hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed
pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed
pip list | grep neural # verify that neural_compressor is installed
pip list | grep neural # verify that neural_compressor_pt is installed
```

Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade)
@@ -120,12 +120,13 @@ docker run \
- Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html)
  for accelerating low-batch latency and throughput
- Attention with Linear Biases (ALiBi)
- INC quantization

### Unsupported features

- Beam search
- LoRA adapters
- Quantization
- AWQ quantization
- Prefill chunking (mixed-batch inferencing)

### Supported configurations
@@ -963,7 +963,7 @@ class ModelConfig:
        optimized_quantization_methods = [
            "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
            "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
            "quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
            "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
        ]
        if self.quantization is not None:
            self.quantization = cast(me_quant.QuantizationMethods,
@@ -1563,7 +1563,7 @@ class ModelConfig:


BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
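
A standalone editor's sketch (not vLLM's actual validation code) of what widening this `Literal` means at argument-checking time; the helper name is hypothetical:

```python
from typing import Literal, get_args

CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]

def check_cache_dtype(value: str) -> str:
    # Reject kv-cache dtypes that are not listed in the Literal above.
    if value not in get_args(CacheDType):
        raise ValueError(f"Unknown kv cache dtype: {value}")
    return value

check_cache_dtype("fp8_inc")  # accepted; used with --kv-cache-dtype fp8_inc on Gaudi
```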
@@ -1593,7 +1593,7 @@ class CacheConfig:
    cache_dtype: CacheDType = "auto"
    """Data type for kv cache storage. If "auto", will use model data type.
    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
    fp8 (=fp8_e4m3)."""
    fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
    is_attention_free: bool = False
    """Whether the model is attention-free. This is primarily set in
    `ModelConfig` and that value should be manually duplicated here."""
@@ -1691,7 +1691,7 @@ class CacheConfig:
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor")
                "scaling factor.")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
@@ -1781,6 +1781,9 @@ class LoadConfig:
        default_factory=dict)
    """Extra config for model loader. This will be passed to the model loader
    corresponding to the chosen load_format."""
    device: Optional[str] = None
    """Device to which model weights will be loaded, default to
    device_config.device"""
    ignore_patterns: Optional[Union[list[str], str]] = None
    """The list of patterns to ignore when loading the model. Default to
    "original/**/*" to avoid repeated loading of llama's checkpoints."""
@@ -1907,7 +1910,7 @@ class ParallelConfig:
    or equal to the number of GPUs available, "mp" will be used to
    keep processing on a single host. Otherwise, this will default
    to "ray" if Ray is installed and fail otherwise. Note that tpu
    and hpu only support Ray for distributed inference."""
    only support Ray for distributed inference."""

    worker_cls: str = "auto"
    """The full name of the worker class to use. If "auto", the worker class
@@ -139,6 +139,10 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    return type_hints


def is_online_quantization(quantization: Any) -> bool:
    return quantization in ["inc"]


@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
    cls_docs = get_attr_docs(cls)
@@ -960,6 +964,8 @@ class EngineArgs:
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            device="cpu"
            if is_online_quantization(self.quantization) else None,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
@@ -1359,7 +1365,9 @@ class EngineArgs:
            supported = False
            if current_platform.is_rocm() or (
                    current_platform.is_cuda()
                    and current_platform.is_device_capability(100)):
                    and current_platform.is_device_capability(100)) or (
                        current_platform.device_name
                        == "hpu"):  # handle hpu also for OOT platform
                supported = True
            elif fp8_attention and will_use_fa:
                from vllm.attention.utils.fa_utils import (
@@ -36,6 +36,7 @@ QuantizationMethods = Literal[
    "torchao",
    "auto-round",
    "rtn",
    "inc",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -104,6 +105,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
    from .gptq_marlin import GPTQMarlinConfig
    from .gptq_marlin_24 import GPTQMarlin24Config
    from .hqq_marlin import HQQMarlinConfig
    from .inc import INCConfig
    from .ipex_quant import IPEXConfig
    from .marlin import MarlinConfig
    from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
@@ -144,7 +146,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
        "moe_wna16": MoeWNA16Config,
        "torchao": TorchAOConfig,
        "auto-round": AutoRoundConfig,
        "rtn": RTNConfig
        "rtn": RTNConfig,
        "inc": INCConfig,
    }
    # Update the `method_to_config` with customized quantization methods.
    method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
@@ -157,4 +160,4 @@ __all__ = [
    "QuantizationMethods",
    "get_quantization_config",
    "QUANTIZATION_METHODS",
]
]
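
For orientation (an editor's sketch, not part of the diff): once registered here, the new method resolves by name like any other quantization backend.

```python
from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, get_quantization_config)

assert "inc" in QUANTIZATION_METHODS
inc_cls = get_quantization_config("inc")  # returns INCConfig (new file below)
print(inc_cls.get_name())                 # -> "inc"
```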
vllm/model_executor/layers/quantization/inc.py (new file, 61 lines)
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Intel Gaudi supports quantization of various modules and functions,
# including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`.
# During model loading,
# INC will patch layers with quantization/dequantization operators.
# Meanwhile, INC will convert the original weights to the target data type
# and load them to the target device.
# Static scaling factors should be provided through QUANT_CONFIG:
# `QUANT_CONFIG` is an environment variable
# that points to the measurement or quantization JSON config file.
# The measurement configuration file is used during the calibration procedure
# to collect measurements for a given model.
# The quantization configuration is used during inference.
# For more information, please refer to:
# https://docs.habana.ai/en/v1.21.1/PyTorch/vLLM_Inference/vLLM_FP8_Inference.html

from typing import Any, Optional

import torch

from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)


class INCConfig(QuantizationConfig):
    """Config class for FP8 using Intel Neural Compressor."""

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "inc"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "INCConfig":
        # INC does not read a quantization config from the checkpoint;
        # the config object is constructed directly (see get_quant_config).
        raise AssertionError

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        # INC patches the quantized kernels in during model loading, so vLLM
        # keeps the unquantized compute methods for these layers.
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        elif isinstance(layer, FusedMoE):
            return UnquantizedFusedMoEMethod(layer.moe_config)
        return None

    @classmethod
    def get_min_capability(cls) -> int:
        # Not used for INC (no CUDA compute-capability requirement on HPU).
        raise AssertionError

    @staticmethod
    def get_config_filenames() -> list[str]:
        return []
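
A quick editor-added usage sketch: because INC carries its configuration in the `QUANT_CONFIG` JSON file rather than in the checkpoint, the config object is constructed directly, mirroring the `get_quant_config` change shown further below.

```python
import torch
from vllm.model_executor.layers.quantization.inc import INCConfig

cfg = INCConfig()  # no checkpoint config file is needed
assert cfg.get_name() == "inc"
assert cfg.get_supported_act_dtypes() == [torch.bfloat16]
```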
@@ -6,9 +6,12 @@ import torch
import torch.nn as nn

from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import (
    initialize_model, process_weights_after_loading, set_default_torch_dtype)

logger = init_logger(__name__)


class BaseModelLoader(ABC):
    """Base class for model loaders."""
@@ -32,11 +35,16 @@
                   model_config: ModelConfig) -> nn.Module:
        """Load a model with the given configurations."""
        device_config = vllm_config.device_config
        target_device = torch.device(device_config.device)
        load_config = vllm_config.load_config
        load_device = device_config.device if load_config.device is None else \
            load_config.device
        target_device = torch.device(load_device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config,
                                         model_config=model_config)

        logger.debug("Loading weights on %s ...", load_device)
        # Quantization does not happen in `load_weights` but after it
        self.load_weights(model, model_config)
        process_weights_after_loading(model, model_config, target_device)
@@ -152,8 +152,8 @@ def get_quant_config(model_config: ModelConfig,
    quant_cls = get_quantization_config(model_config.quantization)

    # GGUF doesn't have config file
    if model_config.quantization == "gguf":
        return quant_cls.from_config({})
    if model_config.quantization in ("gguf", "inc"):
        return quant_cls()

    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
@@ -179,6 +179,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
    "int8": torch.int8,
    "fp8_inc": torch.float8_e4m3fn,
}

TORCH_DTYPE_TO_NUMPY_DTYPE = {
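
Finally, a small editor-added sketch of what the new mapping entry means in practice; the import path of the table is assumed to be `vllm.utils`.

```python
import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE  # assumed location of the table

# Unlike "fp8"/"fp8_e4m3"/"fp8_e5m2", which are stored as uint8 views, the
# Gaudi/INC kv-cache dtype maps to a native float8 torch dtype.
assert STR_DTYPE_TO_TORCH_DTYPE["fp8_inc"] is torch.float8_e4m3fn
```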