Update nm to rht in doc links + refine fp8 doc (#17678)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 90bd2ae172
commit 98834fefaa
@@ -19,24 +19,6 @@ FP8 computation is supported on NVIDIA GPUs with compute capability > 8.9 (Ada L
 FP8 models will run on compute capability > 8.0 (Ampere) as weight-only W8A16, utilizing FP8 Marlin.
 :::
 
-## Quick Start with Online Dynamic Quantization
-
-Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying `--quantization="fp8"` in the command line or setting `quantization="fp8"` in the LLM constructor.
-
-In this mode, all Linear modules (except for the final `lm_head`) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode.
-
-```python
-from vllm import LLM
-model = LLM("facebook/opt-125m", quantization="fp8")
-# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB
-result = model.generate("Hello, my name is")
-print(result[0].outputs[0].text)
-```
-
-:::{warning}
-Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model.
-:::
-
 ## Installation
 
 To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
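The `--quantization="fp8"` flag mentioned in the quick-start text above also applies when serving a model rather than using the `LLM` constructor. A minimal sketch, with a placeholder model name:

```console
vllm serve meta-llama/Llama-3.1-8B-Instruct --quantization fp8
```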
@@ -45,12 +27,6 @@ To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
 pip install llmcompressor
 ```
 
-Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
-
-```console
-pip install vllm lm-eval==0.4.4
-```
-
 ## Quantization Process
 
 The quantization process involves three main steps:
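For reference, the first two of those steps typically look like the sketch below when using `llmcompressor`; the model ID and save directory are illustrative, and exact import paths may differ between `llmcompressor` versions:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative model choice

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# FP8 weights with dynamic FP8 activations; keep lm_head in original precision.
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
oneshot(model=model, recipe=recipe)

# Save in compressed-tensors format so vLLM can load the checkpoint directly.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```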
@@ -101,6 +77,12 @@ tokenizer.save_pretrained(SAVE_DIR)
 
 ### 3. Evaluating Accuracy
 
+Install `vllm` and `lm-evaluation-harness` for evaluation:
+
+```console
+pip install vllm lm-eval==0.4.4
+```
+
 Load and run the model in `vllm`:
 
 ```python
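Once the quantized checkpoint loads in `vllm`, accuracy can be spot-checked through `lm-evaluation-harness`. The checkpoint path, task, and sample limit below are placeholders rather than values taken from this diff:

```console
MODEL=Meta-Llama-3-8B-Instruct-FP8-Dynamic
lm_eval \
  --model vllm \
  --model_args pretrained=$MODEL,add_bos_token=true \
  --tasks gsm8k \
  --num_fewshot 5 \
  --limit 250 \
  --batch_size auto
```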
@@ -137,58 +119,20 @@ Here's an example of the resulting scores:
 
 If you encounter any issues or have feature requests, please open an issue on the `vllm-project/llm-compressor` GitHub repository.
 
-## Deprecated Flow
+## Online Dynamic Quantization
 
-:::{note}
-The following information is preserved for reference and search purposes.
-The quantization method described below is deprecated in favor of the `llmcompressor` method described above.
-:::
+Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying `--quantization="fp8"` in the command line or setting `quantization="fp8"` in the LLM constructor.
 
-For static per-tensor offline quantization to FP8, please install the [AutoFP8 library](https://github.com/neuralmagic/autofp8).
+In this mode, all Linear modules (except for the final `lm_head`) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode.
 
-```bash
-git clone https://github.com/neuralmagic/AutoFP8.git
-pip install -e AutoFP8
-```
-
-This package introduces the `AutoFP8ForCausalLM` and `BaseQuantizeConfig` objects for managing how your model will be compressed.
-
-## Offline Quantization with Static Activation Scaling Factors
-
-You can use AutoFP8 with calibration data to produce per-tensor static scales for both the weights and activations by enabling the `activation_scheme="static"` argument.
-
-```python
-from datasets import load_dataset
-from transformers import AutoTokenizer
-from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
-
-pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
-quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
-
-tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
-tokenizer.pad_token = tokenizer.eos_token
-
-# Load and tokenize 512 dataset samples for calibration of activation scales
-ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
-examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
-examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
-
-# Define quantization config with static activation scales
-quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
-
-# Load the model, quantize, and save checkpoint
-model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
-model.quantize(examples)
-model.save_quantized(quantized_model_dir)
-```
-
-Your model checkpoint with quantized weights and activations should be available at `Meta-Llama-3-8B-Instruct-FP8/`.
-Finally, you can load the quantized model checkpoint directly in vLLM.
 
 ```python
 from vllm import LLM
-model = LLM(model="Meta-Llama-3-8B-Instruct-FP8/")
-# INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
+model = LLM("facebook/opt-125m", quantization="fp8")
+# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB
 result = model.generate("Hello, my name is")
 print(result[0].outputs[0].text)
 ```
+
+:::{warning}
+Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model.
+:::
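To make the re-added warning concrete: memory at load time is governed by the original-precision weights, not the FP8 ones. A rough back-of-the-envelope estimate, using an illustrative 8B-parameter model:

```python
def weight_gib(num_params: float, bytes_per_param: float) -> float:
    """Approximate weight memory in GiB, ignoring activations and KV cache."""
    return num_params * bytes_per_param / 1024**3

# The model is first loaded in BF16 (2 bytes/param) and only then quantized
# down to FP8 (1 byte/param), so peak load memory follows the BF16 figure.
print(f"load in BF16: {weight_gib(8e9, 2):.1f} GiB")  # ~14.9 GiB
print(f"after FP8:    {weight_gib(8e9, 1):.1f} GiB")  # ~7.5 GiB
```

The second figure is roughly in line with the ~8.5 GB reported in the removed `Meta-Llama-3-8B-Instruct-FP8` log line above.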
@@ -95,7 +95,7 @@ You can convert the model checkpoint to a sharded checkpoint using <gh-file:exam
 
 Quantized models take less memory at the cost of lower precision.
 
-Statically quantized models can be downloaded from HF Hub (some popular ones are available at [Neural Magic](https://huggingface.co/neuralmagic))
+Statically quantized models can be downloaded from HF Hub (some popular ones are available at [Red Hat AI](https://huggingface.co/RedHatAI))
 and used directly without extra configuration.
 
 Dynamic quantization is also supported via the `quantization` option -- see [here](#quantization-index) for more details.
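As a usage sketch for the renamed link target: a statically quantized FP8 checkpoint from the Hub can be pointed at directly, with no extra flags. The repository ID below is illustrative:

```python
from vllm import LLM

# Substitute any FP8 checkpoint published under the Red Hat AI (formerly Neural Magic) org.
llm = LLM(model="RedHatAI/Meta-Llama-3-8B-Instruct-FP8")
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```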