Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through.
It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors.
Follow up of https://github.com/vllm-project/vllm/pull/3095/files
Follow on to #4332 to enable FP8 checkpoint loading for Mixtral and supersedes #4436.
This PR enables the following checkpoint loading features for Mixtral:
Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:
The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
This PR is the first step towards fixing https://github.com/vllm-project/vllm/pull/3208
It implements dynamic per-tensor scaling (see https://github.com/vllm-project/vllm/pull/4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this:
```python
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8")
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
**Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in https://github.com/vllm-project/vllm/pull/3954). With this PR, the results are as follows:
<img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03">
**Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows:
```
| Groups |Version|Filter|n-shot|Metric|Value | |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu |N/A |none | 0|acc |0.7018|± |0.0036|
| - humanities |N/A |none | 5|acc |0.6472|± |0.0065|
| - other |N/A |none | 5|acc |0.7673|± |0.0072|
| - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070|
| - stem |N/A |none | 5|acc |0.6131|± |0.0083|
```
this compares favorably with the fp16 results which are
```
| Groups |Version|Filter|n-shot|Metric|Value | |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu |N/A |none | 0|acc |0.7020|± |0.1313|
| - humanities |N/A |none | 5|acc |0.6425|± |0.1349|
| - other |N/A |none | 5|acc |0.7744|± |0.1038|
| - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695|
| - stem |N/A |none | 5|acc |0.6108|± |0.1383|
```
Happy hacking!