Add NVIDIA TensorRT Model Optimizer in vLLM documentation (#17561)

This commit is contained in:
Zhiyu 2025-05-02 11:36:46 -07:00 committed by GitHub
parent 3e887d2e0c
commit 182f40ea8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 90 additions and 1 deletions

View File

@ -17,6 +17,7 @@ gptqmodel
int4
int8
fp8
modelopt
quark
quantized_kvcache
torchao

View File

@ -0,0 +1,78 @@
# NVIDIA TensorRT Model Optimizer
The [NVIDIA TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a library designed to optimize models for inference with NVIDIA GPUs. It includes tools for Post-Training Quantization (PTQ) and Quantization Aware Training (QAT) of Large Language Models (LLMs), Vision Language Models (VLMs), and diffusion models.
We recommend installing the library with:
```console
pip install nvidia-modelopt
```
## Quantizing HuggingFace Models with PTQ
You can quantize HuggingFace models using the example scripts provided in the TensorRT Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory.
Below is an example showing how to quantize a model using modelopt's PTQ API:
```python
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM
# Load the model from HuggingFace
model = AutoModelForCausalLM.from_pretrained("<path_or_model_id>")
# Select the quantization config, for example, FP8
config = mtq.FP8_DEFAULT_CFG
# Define a forward loop function for calibration
def forward_loop(model):
for data in calib_set:
model(data)
# PTQ with in-place replacement of quantized modules
model = mtq.quantize(model, config, forward_loop)
```
After the model is quantized, you can export it to a quantized checkpoint using the export API:
```python
import torch
from modelopt.torch.export import export_hf_checkpoint
with torch.inference_mode():
export_hf_checkpoint(
model, # The quantized model.
export_dir, # The directory where the exported files will be stored.
)
```
The quantized checkpoint can then be deployed with vLLM. As an example, the following code shows how to deploy `nvidia/Llama-3.1-8B-Instruct-FP8`, which is the FP8 quantized checkpoint derived from `meta-llama/Llama-3.1-8B-Instruct`, using vLLM:
```python
from vllm import LLM, SamplingParams
def main():
model_id = "nvidia/Llama-3.1-8B-Instruct-FP8"
# Ensure you specify quantization='modelopt' when loading the modelopt checkpoint
llm = LLM(model=model_id, quantization="modelopt", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.8, top_p=0.9)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
if __name__ == "__main__":
main()
```

View File

@ -129,7 +129,17 @@ The table below shows the compatibility of various quantization implementations
* ❌
* ❌
* ❌
- * modelopt
* ✅︎
* ✅︎
* ✅︎
* ✅︎
* ✅︎︎
* ❌
* ❌
* ❌
* ❌
* ❌
:::
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.