Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-12 15:36:14 +08:00)
[Bugfix] bugfix and add model test for flashinfer fp8 kv cache. (#8013)
parent 1248e8506a
commit 622f8abff8
tests/models/test_fp8kv_flashinfer.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# flake8: noqa
"""Tests fp8 models against ground truth generation
This verifies the flashinfer backend with fp8
quantization and fp8 KV Cache without scaling
factors. Note: these tests will only pass on H100 GPU.
"""
import os
from typing import List

import pytest
from transformers import AutoTokenizer

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

MODELS = [
    "nm-testing/Meta-Llama-3-8B-Instruct-FP8",
]

EXPECTED_STRS_MAP = {
    "nm-testing/Meta-Llama-3-8B-Instruct-FP8": {
        "auto": [
            'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
            'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
            'In the sterile, metallic halls of the robotics lab, a peculiar phenomenon occurred. Zeta-5',
            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, mushi o',
        ],
        "fp8": [
            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
            'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
        ]
    }
}


# This test compares against golden strings for exact match since
# there is no baseline implementation to compare against
# and is unstable w.r.t specifics of the fp8 implementation or
# the hardware being run on.
# No assert to prevent it from breaking the build
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("backend", ["XFORMERS", "FLASHINFER"])
def test_models(example_prompts, model_name, kv_cache_dtype, backend) -> None:
    # Note that the golden strings may not work for FLASHINFER Backend.
    # The intention is to test the path
    os.environ["VLLM_ATTENTION_BACKEND"] = backend
    model = LLM(model=model_name,
                max_model_len=MAX_MODEL_LEN,
                trust_remote_code=True,
                quantization="fp8",
                kv_cache_dtype=kv_cache_dtype)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    formatted_prompts = [
        tokenizer.apply_chat_template([{
            "role": "user",
            "content": prompt
        }],
                                      tokenize=False,
                                      add_generation_prompt=True)
        for prompt in example_prompts
    ]

    params = SamplingParams(max_tokens=20, temperature=0)
    generations: List[str] = []
    # Note: these need to be run 1 at a time due to numerical precision,
    # since the expected strs were generated this way.
    for prompt in formatted_prompts:
        outputs = model.generate(prompt, params)
        generations.append(outputs[0].outputs[0].text)
    del model

    print(f"Testing: {model_name} with kv_cache_dtype: {kv_cache_dtype}")
    expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
    for i in range(len(example_prompts)):
        generated_str = generations[i]
        expected_str = expected_strs[i]
        print(f"generated_str\n: {generated_str}")
        print(f"expected_str\n: {expected_str}")
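
For reference, here is a minimal sketch (not part of the commit) of driving the same fp8 KV-cache path outside pytest. The model name, backend name, and sampling settings are taken from the test above; the prompt is illustrative, and an fp8-capable GPU (e.g. H100) with flashinfer installed is assumed.

# Minimal sketch mirroring what the test above exercises; prompt is illustrative.
import os

from vllm import LLM, SamplingParams

# Select the FlashInfer attention backend before constructing the engine.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-FP8",
          max_model_len=1024,
          quantization="fp8",
          kv_cache_dtype="fp8")  # fp8 KV cache without scaling factors

outputs = llm.generate("What is a neural network?",
                       SamplingParams(max_tokens=20, temperature=0))
print(outputs[0].outputs[0].text)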
@@ -186,9 +186,13 @@ class FlashInferState(AttentionState):
             self._graph_decode_workspace_buffer, _indptr_buffer,
             self._graph_indices_buffer, _last_page_len_buffer, "NHD",
             use_tensor_cores)
-        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-            self.runner.kv_cache_dtype)
+        if self.runner.kv_cache_dtype.startswith("fp8"):
+            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.runner.kv_cache_dtype)
+        else:
+            kv_cache_dtype = get_kv_cache_torch_dtype(
+                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

         paged_kv_indptr_tensor_host = torch.arange(0,
                                                    batch_size + 1,
                                                    dtype=torch.int32)
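
The hunk above is the core of the bugfix: during CUDA-graph capture the dtype for the decode wrapper was previously derived unconditionally via FlashInferBackend.get_fp8_dtype_for_flashinfer, so a non-fp8 setting such as kv_cache_dtype="auto" could not take that path; the new branch only uses the fp8 helper for fp8 caches and otherwise falls back to get_kv_cache_torch_dtype. As a rough illustration, the selection amounts to the sketch below; the fp8 string-to-torch-dtype mapping shown is an assumption for illustration, not taken from this diff.

# Illustrative stand-in for the dtype selection added above; the real logic
# lives in FlashInferBackend.get_fp8_dtype_for_flashinfer and the vLLM helper
# get_kv_cache_torch_dtype. The fp8 mapping below is an assumption.
import torch


def select_kv_cache_dtype(kv_cache_dtype: str,
                          model_dtype: torch.dtype) -> torch.dtype:
    if kv_cache_dtype.startswith("fp8"):
        # FlashInfer needs a concrete torch fp8 dtype, not the "fp8" string.
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        if kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        raise ValueError(f"Unrecognized fp8 KV cache dtype: {kv_cache_dtype}")
    # "auto" and friends fall back to the model's own dtype (e.g. float16).
    return model_dtype


print(select_kv_cache_dtype("fp8", torch.float16))   # torch.float8_e4m3fn
print(select_kv_cache_dtype("auto", torch.float16))  # torch.float16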
@@ -349,7 +353,7 @@ class FlashInferMetadata(AttentionMetadata):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
-            )
+                data_type=self.data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
@@ -586,6 +590,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_indptr_tensor = None
             paged_kv_last_page_len_tensor = None

-        kv_cache_dtype = get_kv_cache_torch_dtype(
-            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+        if self.runner.kv_cache_dtype.startswith("fp8"):
+            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.runner.kv_cache_dtype)
+        else:
+            kv_cache_dtype = get_kv_cache_torch_dtype(
+                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

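
The same branch is mirrored in FlashInferMetadataBuilder so that the per-batch metadata carries the fp8 torch dtype whenever an fp8 KV cache is configured. The practical payoff of the fp8 KV cache this test covers is footprint: fp8 halves KV-cache memory versus fp16. A back-of-the-envelope check, using roughly Llama-3-8B-style dimensions (32 layers, 8 KV heads with GQA, head size 128; these numbers are assumptions for illustration, not from the diff):

# Rough KV-cache bytes per token for fp16 vs fp8; model dimensions are assumed.
import torch

num_layers, num_kv_heads, head_size = 32, 8, 128


def kv_bytes_per_token(dtype: torch.dtype) -> int:
    elem = torch.empty((), dtype=dtype).element_size()
    return 2 * num_layers * num_kv_heads * head_size * elem  # 2 = key + value


for dtype in (torch.float16, torch.float8_e4m3fn):
    print(dtype, kv_bytes_per_token(dtype), "bytes per token")
# fp16: 131072 bytes/token, fp8: 65536 bytes/token -> roughly 2x the tokens fit.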