[Doc] Add docs for llmcompressor INT8 and FP8 checkpoints (#7444)
This commit is contained in:
parent
93478b63d2
commit
b3f4e17935
@ -107,6 +107,7 @@ Documentation

   quantization/supported_hardware
   quantization/auto_awq
   quantization/bnb
   quantization/int8
   quantization/fp8
   quantization/fp8_e5m2_kvcache
   quantization/fp8_e4m3_kvcache
@ -1,6 +1,6 @@

.. _fp8:

FP8 W8A8
==================

vLLM supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x.
@ -15,6 +15,11 @@ The FP8 types typically supported in hardware have two distinct representations,

- **E4M3**: Consists of 1 sign bit, 4 exponent bits, and 3 bits of mantissa. It can store values up to +/-448 and ``nan``.
- **E5M2**: Consists of 1 sign bit, 5 exponent bits, and 2 bits of mantissa. It can store values up to +/-57344, +/- ``inf``, and ``nan``. The tradeoff for the increased dynamic range is lower precision of the stored values.

.. note::

    FP8 computation is supported on NVIDIA GPUs with compute capability >= 8.9 (Ada Lovelace, Hopper).
    FP8 models will also run on GPUs with compute capability >= 8.0 (Ampere) as weight-only W8A16, utilizing FP8 Marlin.
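To make the range/precision tradeoff concrete, the following is a minimal sketch (assuming PyTorch >= 2.1, which exposes the ``torch.float8_e4m3fn`` and ``torch.float8_e5m2`` dtypes) that inspects the numeric limits of the two formats:

.. code-block:: python

    import torch

    # Print the representable range of each FP8 format.
    for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        info = torch.finfo(dtype)
        print(dtype, "max:", info.max, "smallest normal:", info.tiny)

    # Round-tripping through E4M3 keeps only a few significant digits.
    x = torch.tensor([3.14159, 300.0])
    print(x.to(torch.float8_e4m3fn).to(torch.float32))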
Quick Start with Online Dynamic Quantization
--------------------------------------------

@ -33,10 +38,122 @@ In this mode, all Linear modules (except for the final ``lm_head``) have their w

Currently, we load the model at original precision before quantizing down to 8 bits, so you need enough memory to load the whole model.
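As a quick illustration, online dynamic quantization is enabled by passing ``quantization="fp8"`` when constructing the engine (a minimal sketch; the model name is only an example):

.. code-block:: python

    from vllm import LLM

    # Weights are quantized to FP8 on the fly; activation scales are computed at runtime.
    model = LLM("meta-llama/Meta-Llama-3-8B-Instruct", quantization="fp8")
    result = model.generate("Hello, my name is")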
Installation
------------

To produce performant FP8 quantized models with vLLM, you'll need to install the `llm-compressor <https://github.com/vllm-project/llm-compressor/>`_ library:

.. code-block:: console

    $ pip install llmcompressor==0.1.0
Quantization Process
--------------------

The quantization process involves three main steps:

1. Loading the model
2. Applying quantization
3. Evaluating accuracy in vLLM
1. Loading the Model
^^^^^^^^^^^^^^^^^^^^

Use ``SparseAutoModelForCausalLM``, which wraps ``AutoModelForCausalLM``, for saving and loading quantized models:

.. code-block:: python

    from llmcompressor.transformers import SparseAutoModelForCausalLM
    from transformers import AutoTokenizer

    MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

    model = SparseAutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
2. Applying Quantization
^^^^^^^^^^^^^^^^^^^^^^^^

For FP8 quantization, we can recover accuracy with simple RTN (round-to-nearest) quantization. We recommend targeting all ``Linear`` layers using the ``FP8_DYNAMIC`` scheme, which uses:

- Static, per-channel quantization on the weights
- Dynamic, per-token quantization on the activations

Since simple RTN does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow.
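To illustrate what these two scale granularities compute, here is a small standalone sketch in plain PyTorch (illustrative only, not part of the llm-compressor flow), followed by the actual recipe:

.. code-block:: python

    import torch

    FP8_MAX = 448.0                  # E4M3 maximum representable magnitude
    W = torch.randn(4096, 4096)      # weight matrix [out_features, in_features]
    X = torch.randn(8, 4096)         # a batch of 8 token activations

    # Static per-channel: one scale per output channel, computed once offline.
    w_scales = W.abs().amax(dim=1) / FP8_MAX

    # Dynamic per-token: one scale per token, recomputed at runtime.
    x_scales = X.abs().amax(dim=1) / FP8_MAX
    print(w_scales.shape, x_scales.shape)  # torch.Size([4096]) torch.Size([8])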
.. code-block:: python

    from llmcompressor.transformers import oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier

    # Configure the simple PTQ quantization
    recipe = QuantizationModifier(
        targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

    # Apply the quantization algorithm.
    oneshot(model=model, recipe=recipe)

    # Save the model.
    SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)
3. Evaluating Accuracy
^^^^^^^^^^^^^^^^^^^^^^

Install ``vllm`` and ``lm-evaluation-harness``:

.. code-block:: console

    $ pip install vllm lm_eval==0.4.3

Load and run the model in ``vllm``:

.. code-block:: python

    from vllm import LLM
    model = LLM("./Meta-Llama-3-8B-Instruct-FP8-Dynamic")
    model.generate("Hello my name is")

Evaluate accuracy with ``lm_eval`` (for example, on 250 samples of ``gsm8k``):

.. note::

    Quantized models can be sensitive to the presence of the ``bos`` token. ``lm_eval`` does not add a ``bos`` token by default, so make sure to include the ``add_bos_token=True`` argument when running your evaluations.

.. code-block:: console

    $ MODEL=$PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic
    $ lm_eval \
        --model vllm \
        --model_args pretrained=$MODEL,add_bos_token=True \
        --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250
Here's an example of the resulting scores:

.. code-block:: text

    |Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
    |-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
    |gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.768|±  |0.0268|
    |     |       |strict-match    |     5|exact_match|↑  |0.768|±  |0.0268|
Troubleshooting and Support
---------------------------

If you encounter any issues or have feature requests, please open an issue on the ``vllm-project/llm-compressor`` GitHub repository.
Deprecated Flow
------------------

.. note::

    The following information is preserved for reference and search purposes.
    The quantization method described below is deprecated in favor of the ``llmcompressor`` method described above.

For static per-tensor offline quantization to FP8, please install the `AutoFP8 library <https://github.com/neuralmagic/autofp8>`_.

.. code-block:: bash

@ -45,94 +162,10 @@ For offline quantization to FP8, please install the `AutoFP8 library <https://gi

This package introduces the ``AutoFP8ForCausalLM`` and ``BaseQuantizeConfig`` objects for managing how your model will be compressed.
Offline Quantization with Dynamic Activation Scaling Factors
------------------------------------------------------------

You can use AutoFP8 to produce checkpoints with their weights quantized to FP8 ahead of time, letting vLLM compute dynamic scales for the activations at runtime for maximum accuracy. You can enable this with the ``activation_scheme="dynamic"`` argument.

.. warning::

    Please note that although this mode doesn't give you better performance than online quantization, it does reduce the memory footprint, since the checkpoint is loaded directly in FP8 rather than at the original precision.
.. code-block:: python

    from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

    pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
    quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8-Dynamic"

    # Define quantization config with dynamic activation scales
    quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="dynamic")
    # For dynamic activation scales, there is no need for calibration examples
    examples = []

    # Load the model, quantize, and save checkpoint
    model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
    model.quantize(examples)
    model.save_quantized(quantized_model_dir)

In the output of the above script, you should be able to see the quantized Linear modules (``FP8DynamicLinear``) replaced in the model definition.
Note that the ``lm_head`` Linear module at the end is currently skipped by default.
.. code-block:: text

    LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): FP8DynamicLinear()
              (k_proj): FP8DynamicLinear()
              (v_proj): FP8DynamicLinear()
              (o_proj): FP8DynamicLinear()
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): FP8DynamicLinear()
              (up_proj): FP8DynamicLinear()
              (down_proj): FP8DynamicLinear()
              (act_fn): SiLU()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
          )
        )
        (norm): LlamaRMSNorm()
      )
      (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
    )
    Saving the model to Meta-Llama-3-8B-Instruct-FP8-Dynamic
Your model checkpoint with quantized weights should be available at ``Meta-Llama-3-8B-Instruct-FP8-Dynamic/``.
We can see that the weights are smaller than they would be in the original BF16 precision.
.. code-block:: bash

    ls -lh Meta-Llama-3-8B-Instruct-FP8-Dynamic/
    total 8.5G
    -rw-rw-r-- 1 user user  869 Jun  7 14:43 config.json
    -rw-rw-r-- 1 user user  194 Jun  7 14:43 generation_config.json
    -rw-rw-r-- 1 user user 4.7G Jun  7 14:43 model-00001-of-00002.safetensors
    -rw-rw-r-- 1 user user 3.9G Jun  7 14:43 model-00002-of-00002.safetensors
    -rw-rw-r-- 1 user user  43K Jun  7 14:43 model.safetensors.index.json
    -rw-rw-r-- 1 user user  296 Jun  7 14:43 special_tokens_map.json
    -rw-rw-r-- 1 user user  50K Jun  7 14:43 tokenizer_config.json
    -rw-rw-r-- 1 user user 8.7M Jun  7 14:43 tokenizer.json
Finally, you can load the quantized model checkpoint directly in vLLM.

.. code-block:: python

    from vllm import LLM
    model = LLM(model="Meta-Llama-3-8B-Instruct-FP8-Dynamic/")
    # INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
    result = model.generate("Hello, my name is")
Offline Quantization with Static Activation Scaling Factors
-----------------------------------------------------------

You can use AutoFP8 with calibration data to produce per-tensor static scales for both the weights and activations by enabling the ``activation_scheme="static"`` argument.

.. code-block:: python

@ -169,41 +202,3 @@ Finally, you can load the quantized model checkpoint directly in vLLM.

    # INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
    result = model.generate("Hello, my name is")
FP8 checkpoint structure explanation
-----------------------------------------------------------

Here we detail the structure of FP8 checkpoints.

The following must be present in the model's ``config.json``:

.. code-block:: text

    "quantization_config": {
        "quant_method": "fp8",
        "activation_scheme": "static" or "dynamic"
    }
Each quantized layer in the state_dict will have these tensors:

* If the config has ``"activation_scheme": "static"``:

  .. code-block:: text

      model.layers.0.mlp.down_proj.weight          < F8_E4M3
      model.layers.0.mlp.down_proj.input_scale     < F32
      model.layers.0.mlp.down_proj.weight_scale    < F32

* If the config has ``"activation_scheme": "dynamic"``:

  .. code-block:: text

      model.layers.0.mlp.down_proj.weight          < F8_E4M3
      model.layers.0.mlp.down_proj.weight_scale    < F32
Additionally, quantized checkpoints can contain `FP8 kv-cache scaling factors <https://github.com/vllm-project/vllm/pull/4893>`_, specified through the ``.kv_scale`` parameter present on the Attention Module, such as:

.. code-block:: text

    model.layers.0.self_attn.kv_scale              < F32
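For reference, here is a minimal sketch (assuming the ``safetensors`` library and the checkpoint directory produced above; the shard filename is illustrative) that inspects these fields in a local checkpoint:

.. code-block:: python

    import json
    from safetensors import safe_open

    CKPT = "Meta-Llama-3-8B-Instruct-FP8-Dynamic"

    # The quantization_config block described above lives in config.json.
    with open(f"{CKPT}/config.json") as f:
        print(json.load(f)["quantization_config"])

    # List a few tensors and their dtypes from the first shard.
    with safe_open(f"{CKPT}/model-00001-of-00002.safetensors", framework="pt") as f:
        for name in list(f.keys())[:6]:
            print(name, f.get_tensor(name).dtype)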
145 docs/source/quantization/int8.rst Normal file
@ -0,0 +1,145 @@
.. _int8:

INT8 W8A8
==================

vLLM supports quantizing weights and activations to INT8 for memory savings and inference acceleration.
This quantization method is particularly useful for reducing model size while maintaining good performance.

Please visit the HF collection of `quantized INT8 checkpoints of popular LLMs ready to use with vLLM <https://huggingface.co/collections/neuralmagic/int8-llms-for-vllm-668ec32c049dca0369816415>`_.
.. note::

    INT8 computation is supported on NVIDIA GPUs with compute capability >= 7.5 (Turing, Ampere, Ada Lovelace, Hopper).
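For example, a pre-quantized checkpoint from the collection above can be served directly (a minimal sketch; the exact repository name below is illustrative and may differ):

.. code-block:: python

    from vllm import LLM

    # Load an INT8 W8A8 checkpoint; vLLM picks up the quantization config automatically.
    model = LLM("neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8")
    result = model.generate("Hello, my name is")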
Prerequisites
-------------

To use INT8 quantization with vLLM, you'll need to install the `llm-compressor <https://github.com/vllm-project/llm-compressor/>`_ library:

.. code-block:: console

    $ pip install llmcompressor==0.1.0
Quantization Process
--------------------

The quantization process involves four main steps:

1. Loading the model
2. Preparing calibration data
3. Applying quantization
4. Evaluating accuracy in vLLM
1. Loading the Model
^^^^^^^^^^^^^^^^^^^^

Use ``SparseAutoModelForCausalLM``, which wraps ``AutoModelForCausalLM``, for saving and loading quantized models:

.. code-block:: python

    from llmcompressor.transformers import SparseAutoModelForCausalLM
    from transformers import AutoTokenizer

    MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
    model = SparseAutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", torch_dtype="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
2. Preparing Calibration Data
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When quantizing activations to INT8, you need sample data to estimate the activation scales.
It's best to use calibration data that closely matches your deployment data.
For a general-purpose instruction-tuned model, you can use a dataset like ``ultrachat``:

.. code-block:: python

    from datasets import load_dataset

    NUM_CALIBRATION_SAMPLES = 512
    MAX_SEQUENCE_LENGTH = 2048

    # Load and preprocess the dataset
    ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

    def preprocess(example):
        return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
    ds = ds.map(preprocess)

    def tokenize(sample):
        return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
    ds = ds.map(tokenize, remove_columns=ds.column_names)
3. Applying Quantization
^^^^^^^^^^^^^^^^^^^^^^^^

Now, apply the quantization algorithms:

.. code-block:: python

    from llmcompressor.transformers import oneshot
    from llmcompressor.modifiers.quantization import GPTQModifier
    from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

    # Configure the quantization algorithms
    recipe = [
        SmoothQuantModifier(smoothing_strength=0.8),
        GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
    ]

    # Apply quantization
    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    )

    # Save the compressed model
    SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    tokenizer.save_pretrained(SAVE_DIR)

This process creates a W8A8 model with weights and activations quantized to 8-bit integers.
4. Evaluating Accuracy
^^^^^^^^^^^^^^^^^^^^^^

After quantization, you can load and run the model in vLLM:

.. code-block:: python

    from vllm import LLM
    model = LLM("./Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token")

To evaluate accuracy, you can use ``lm_eval``:

.. code-block:: console

    $ lm_eval --model vllm \
      --model_args pretrained="./Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token",add_bos_token=true \
      --tasks gsm8k \
      --num_fewshot 5 \
      --limit 250 \
      --batch_size 'auto'

.. note::

    Quantized models can be sensitive to the presence of the ``bos`` token. Make sure to include the ``add_bos_token=True`` argument when running evaluations.
Best Practices
--------------

- Start with 512 samples for calibration data (increase if accuracy drops)
- Use a sequence length of 2048 as a starting point
- Employ the chat template or instruction template that the model was trained with
- If you've fine-tuned a model, consider using a sample of your training data for calibration
Troubleshooting and Support
---------------------------

If you encounter any issues or have feature requests, please open an issue on the ``vllm-project/llm-compressor`` GitHub repository.