diff --git a/docs/features/quantization/README.md b/docs/features/quantization/README.md
index c30abdab5d61..e8c3b1123078 100644
--- a/docs/features/quantization/README.md
+++ b/docs/features/quantization/README.md
@@ -10,6 +10,7 @@ Contents:
 - [BitBLAS](bitblas.md)
 - [GGUF](gguf.md)
 - [GPTQModel](gptqmodel.md)
+- [INC](inc.md)
 - [INT4 W4A16](int4.md)
 - [INT8 W8A8](int8.md)
 - [FP8 W8A8](fp8.md)
diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md
new file mode 100644
index 000000000000..d97a462f5432
--- /dev/null
+++ b/docs/features/quantization/inc.md
@@ -0,0 +1,56 @@
+---
+title: FP8 INC
+---
+[](){ #inc }
+
+vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators.
+Currently, quantization is validated only on Llama models.
+
+Intel Gaudi supports quantization of various modules and functions, including, but not limited to, `Linear`, `KVCache`, `Matmul` and `Softmax`. For more information, please refer to:
+[Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules).
+
+!!! note
+    Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.
+
+!!! note
+    `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options).
+    The measurement configuration file is used during the calibration procedure to collect measurements for a given model. The quantization configuration is used during inference.
+
+## Run Online Inference Using FP8
+
+Once you've completed the model calibration process and collected the measurements, you can run FP8 inference with vLLM using the following command:
+
+```bash
+export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxabs_measure_g3.json
+vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor-parallel-size 8
+```
+
+!!! tip
+    If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments as it causes a significant performance drop.
+
+!!! tip
+    When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the following environment variables:
+    `VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes.
+    `VLLM_RPC_TIMEOUT` - to adjust the RPC protocol timeout used by the OpenAI-compatible API. This value is in milliseconds, e.g., 600000 equals 10 minutes.
+
+## Run Offline Inference Using FP8
+
+To run offline inference (after completing the model calibration process):
+
+* Set the `QUANT_CONFIG` environment variable to point to a JSON configuration file in QUANTIZE mode.
+* Pass `quantization=inc` and `kv_cache_dtype=fp8_inc` as parameters to the `LLM` object.
+* Call the `shutdown` method of the `model_executor` at the end of the run.
+
+```python
+from vllm import LLM
+llm = LLM("llama3.1/Meta-Llama-3.1-8B-Instruct", quantization="inc", kv_cache_dtype="fp8_inc")
+...
+# Call llm.generate on the required prompts and sampling params.
+...
+llm.llm_engine.model_executor.shutdown()
+```
+
+## Device for Loading the Model's Weights
+
+The unquantized weights are first loaded onto the CPU, then quantized and transferred to the target device (HPU) for model execution.
+This reduces the device memory footprint of model weights, as only the quantized weights are stored in device memory.
diff --git a/docs/features/quantization/supported_hardware.md b/docs/features/quantization/supported_hardware.md
index bb4fe5b54b57..70a6a499562a 100644
--- a/docs/features/quantization/supported_hardware.md
+++ b/docs/features/quantization/supported_hardware.md
@@ -2,18 +2,19 @@
 
 The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM:
 
-| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Neuron | Google TPU |
-|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------|------------------|--------------|
-| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
-| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
-| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
-| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ❌ |
-| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
+| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU |
+|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------|
+| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
+| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
+| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
+| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ |
+| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ |
 
 - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
 - ✅︎ indicates that the quantization method is supported on the specified hardware.
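A minimal client-side sketch of the online flow documented in the new `inc.md` page above may help here. It is a hedged illustration, not part of the patch: it assumes the `vllm serve ... --quantization inc --kv-cache-dtype fp8_inc` command is already running on the default OpenAI-compatible endpoint (`http://localhost:8000/v1`), and the model name must match whatever was passed to `vllm serve`.

```python
# Sketch only: assumes a vLLM server started as in inc.md with
# `--quantization inc --kv-cache-dtype fp8_inc`, listening on the default
# OpenAI-compatible endpoint. The served model name is a placeholder.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="meta-llama/Llama-3.1-405B-Instruct",  # must match the served model
    prompt="Briefly explain FP8 quantization.",
    max_tokens=64,
    temperature=0.0,
)
print(completion.choices[0].text)
```

From the client's point of view nothing changes relative to a BF16 deployment; FP8 with INC only affects how the server loads and executes the model.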
diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md
index 09cffb29cb3e..0be0d02d0679 100644
--- a/docs/getting_started/installation/intel_gaudi.md
+++ b/docs/getting_started/installation/intel_gaudi.md
@@ -28,7 +28,7 @@ To verify that the Intel Gaudi software was correctly installed, run:
 hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
 apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed
 pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed
-pip list | grep neural # verify that neural_compressor is installed
+pip list | grep neural # verify that neural_compressor_pt is installed
 ```
 
 Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade)
@@ -120,12 +120,13 @@ docker run \
 - Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html)
   for accelerating low-batch latency and throughput
 - Attention with Linear Biases (ALiBi)
+- INC quantization
 
 ### Unsupported features
 
 - Beam search
 - LoRA adapters
-- Quantization
+- AWQ quantization
 - Prefill chunking (mixed-batch inferencing)
 
 ### Supported configurations
diff --git a/vllm/config.py b/vllm/config.py
index 6c56ac1eec81..22f740171369 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -963,7 +963,7 @@ class ModelConfig:
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
-            "quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
+            "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
         ]
         if self.quantization is not None:
             self.quantization = cast(me_quant.QuantizationMethods,
@@ -1563,7 +1563,7 @@ class ModelConfig:
 
 BlockSize = Literal[1, 8, 16, 32, 64, 128]
-CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
+CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
 PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
@@ -1593,7 +1593,7 @@ class CacheConfig:
     cache_dtype: CacheDType = "auto"
     """Data type for kv cache storage. If "auto", will use model data type.
    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
-    fp8 (=fp8_e4m3)."""
+    fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
     is_attention_free: bool = False
     """Whether the model is attention-free. This is primarily set in
     `ModelConfig` and that value should be manually duplicated here."""
@@ -1691,7 +1691,7 @@ class CacheConfig:
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
                 "Meanwhile, it may cause accuracy drop without a proper "
-                "scaling factor")
+                "scaling factor.")
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
@@ -1781,6 +1781,9 @@ class LoadConfig:
         default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format."""
+    device: Optional[str] = None
+    """Device to which model weights will be loaded; defaults to
+    device_config.device."""
     ignore_patterns: Optional[Union[list[str], str]] = None
     """The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's checkpoints.""" @@ -1907,7 +1910,7 @@ class ParallelConfig: or equal to the number of GPUs available, "mp" will be used to keep processing on a single host. Otherwise, this will default to "ray" if Ray is installed and fail otherwise. Note that tpu - and hpu only support Ray for distributed inference.""" + only support Ray for distributed inference.""" worker_cls: str = "auto" """The full name of the worker class to use. If "auto", the worker class diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7b73060e3495..ae5eb46fa967 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -139,6 +139,10 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]: return type_hints +def is_online_quantization(quantization: Any) -> bool: + return quantization in ["inc"] + + @functools.lru_cache(maxsize=30) def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: cls_docs = get_attr_docs(cls) @@ -960,6 +964,8 @@ class EngineArgs: return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, + device="cpu" + if is_online_quantization(self.quantization) else None, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, use_tqdm_on_load=self.use_tqdm_on_load, @@ -1359,7 +1365,9 @@ class EngineArgs: supported = False if current_platform.is_rocm() or ( current_platform.is_cuda() - and current_platform.is_device_capability(100)): + and current_platform.is_device_capability(100)) or ( + current_platform.device_name + == "hpu"): # handle hpu also for OOT platform supported = True elif fp8_attention and will_use_fa: from vllm.attention.utils.fa_utils import ( diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 60217ee86ad1..95aea912a150 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -36,6 +36,7 @@ QuantizationMethods = Literal[ "torchao", "auto-round", "rtn", + "inc", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -104,6 +105,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .gptq_marlin import GPTQMarlinConfig from .gptq_marlin_24 import GPTQMarlin24Config from .hqq_marlin import HQQMarlinConfig + from .inc import INCConfig from .ipex_quant import IPEXConfig from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config @@ -144,7 +146,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "moe_wna16": MoeWNA16Config, "torchao": TorchAOConfig, "auto-round": AutoRoundConfig, - "rtn": RTNConfig + "rtn": RTNConfig, + "inc": INCConfig, } # Update the `method_to_config` with customized quantization methods. 
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
@@ -157,4 +160,4 @@ __all__ = [
     "QuantizationMethods",
     "get_quantization_config",
     "QUANTIZATION_METHODS",
-]
\ No newline at end of file
+]
diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py
new file mode 100644
index 000000000000..8aa1f1a14bfc
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/inc.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Intel Gaudi supports quantization of various modules and functions,
+# including, but not limited to, `Linear`, `KVCache`, `Matmul` and `Softmax`.
+# During model loading,
+# INC patches layers with quantization/dequantization operators.
+# It also converts the original weights to the target data type
+# and loads them onto the target device.
+# Static scaling factors must be provided through `QUANT_CONFIG`:
+# an environment variable
+# that points to the measurement or quantization JSON config file.
+# The measurement configuration file is used during the calibration procedure
+# to collect measurements for a given model.
+# The quantization configuration is used during inference.
+# For more information, please refer to:
+# https://docs.habana.ai/en/v1.21.1/PyTorch/vLLM_Inference/vLLM_FP8_Inference.html
+
+from typing import Any, Optional
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, UnquantizedFusedMoEMethod)
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+
+
+class INCConfig(QuantizationConfig):
+    """Config class for FP8 using Intel Neural Compressor."""
+
+    @classmethod
+    def get_name(cls) -> QuantizationMethods:
+        return "inc"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.bfloat16]
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "INCConfig":
+        raise AssertionError
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return UnquantizedLinearMethod()
+        elif isinstance(layer, FusedMoE):
+            return UnquantizedFusedMoEMethod(layer.moe_config)
+        return None
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        raise AssertionError
+
+    @staticmethod
+    def get_config_filenames() -> list[str]:
+        return []
diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py
index 5018c7d9a360..4cf6c7988960 100644
--- a/vllm/model_executor/model_loader/base_loader.py
+++ b/vllm/model_executor/model_loader/base_loader.py
@@ -6,9 +6,12 @@ import torch
 import torch.nn as nn
 
 from vllm.config import LoadConfig, ModelConfig, VllmConfig
+from vllm.logger import init_logger
 from vllm.model_executor.model_loader.utils import (
     initialize_model, process_weights_after_loading, set_default_torch_dtype)
 
+logger = init_logger(__name__)
+
 
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
@@ -32,11 +35,16 @@ class BaseModelLoader(ABC):
                    model_config: ModelConfig) -> nn.Module:
         """Load a model with the given configurations."""
         device_config = vllm_config.device_config
-        target_device = torch.device(device_config.device)
+        load_config = vllm_config.load_config
+        load_device = device_config.device if load_config.device is None else \
+            load_config.device
+        target_device = torch.device(load_device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
                 model = initialize_model(vllm_config=vllm_config,
                                          model_config=model_config)
+
+        logger.debug("Loading weights on %s ...", load_device)
         # Quantization does not happen in `load_weights` but after it
         self.load_weights(model, model_config)
         process_weights_after_loading(model, model_config, target_device)
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 178b37d7d70b..64a2089921ee 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -152,8 +152,8 @@ def get_quant_config(model_config: ModelConfig,
     quant_cls = get_quantization_config(model_config.quantization)
 
     # GGUF doesn't have config file
-    if model_config.quantization == "gguf":
-        return quant_cls.from_config({})
+    if model_config.quantization in ("gguf", "inc"):
+        return quant_cls()
 
     # Read the quantization config from the HF model config, if available.
     hf_quant_config = getattr(model_config.hf_config, "quantization_config",
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index c18f1d12ba97..bbcc2a523dcb 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -179,6 +179,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
     "int8": torch.int8,
+    "fp8_inc": torch.float8_e4m3fn,
 }
 
 TORCH_DTYPE_TO_NUMPY_DTYPE = {
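Putting the pieces of this change together — `QUANT_CONFIG`, `quantization="inc"`, `kv_cache_dtype="fp8_inc"`, and the CPU-first weight loading path — a minimal offline sketch looks like the following. The config path and model name are placeholders, and it assumes a Gaudi host with a QUANTIZE-mode JSON already produced by the calibration flow in vllm-hpu-extension.

```python
import os

from vllm import LLM, SamplingParams

# Placeholder path: a QUANTIZE-mode JSON produced by the calibration flow
# described in the vllm-hpu-extension README.
os.environ["QUANT_CONFIG"] = "/path/to/maxabs_quant_g3.json"

# With quantization="inc", EngineArgs builds a LoadConfig with device="cpu",
# so unquantized weights are first materialized on the CPU; INC then quantizes
# them and transfers only the FP8 weights to the HPU.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    quantization="inc",
    kv_cache_dtype="fp8_inc",
)

outputs = llm.generate(
    ["What does FP8 quantization change at inference time?"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)

# Release HPU resources held by the model executor at the end of the run.
llm.llm_engine.model_executor.shutdown()
```

Because `EngineArgs` sets `LoadConfig.device="cpu"` for `inc`, only the quantized FP8 weights ever reside in HPU memory, which is exactly the memory saving described in the new `inc.md` section.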