[TPU] support fp8 kv cache quantization (#19292)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Chengji Yao 2025-07-19 20:01:00 -07:00 committed by GitHub
parent 2b504eb770
commit 3a1d8940ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 94 additions and 27 deletions

View File

@ -15,15 +15,18 @@ import pytest
from vllm.platforms import current_platform
MODEL_NAMES = [
"Qwen/Qwen2-1.5B-Instruct",
"Qwen/Qwen3-1.7B",
"google/gemma-3-1b-it",
]
FP8_KV_MODEL_NAMES = [
"Qwen/Qwen3-1.7B",
]
NUM_CONCURRENT = 500
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUES = {
"Qwen/Qwen2-1.5B-Instruct": 0.58,
"Qwen/Qwen3-1.7B": 0.68,
"google/gemma-3-1b-it": 0.25,
}
@ -70,10 +73,9 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
if current_platform.is_tpu():
# Limit compilation time for TPU V1
if model == "google/gemma-3-1b-it":
# TPU + google/gemma-3-1b-it + xet doesn't work well.
m.setenv("HF_HUB_DISABLE_XET", "1")
# xet doesn't work well for both Qwen/Qwen3-1.7B and
# google/gemma-3-1b-it
m.setenv("HF_HUB_DISABLE_XET", "1")
more_args = "max_model_len=2048,max_num_seqs=64"
# Add TP test (if provided)
@ -83,9 +85,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
run_test(model, more_args)
def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
"""Run with the V0 Engine."""
@pytest.mark.skipif(not current_platform.is_cuda()
and not current_platform.is_tpu(),
reason="V1 is currently only supported on CUDA and TPU")
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
model, monkeypatch: pytest.MonkeyPatch):
"""Run with the V1 Engine."""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
run_test("Qwen/Qwen2-1.5B-Instruct")
m.setenv("VLLM_USE_V1", "1")
more_args = None
if current_platform.is_tpu():
# Limit compilation time for TPU V1
# xet doesn't work well for Qwen/Qwen3-1.7B
m.setenv("HF_HUB_DISABLE_XET", "1")
more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
# Add TP test (if provided)
if TPU_TP_TEST_STR:
more_args += ",{}".format(TPU_TP_TEST_STR)
run_test(model, more_args)

View File

@ -95,4 +95,6 @@ def test_ragged_paged_attention():
sm_scale=scale,
sliding_window=sliding_window,
soft_cap=logits_soft_cap,
k_scale=1.0,
v_scale=1.0,
)

View File

@ -1358,10 +1358,10 @@ class EngineArgs:
and not envs.is_set("VLLM_ATTENTION_BACKEND")
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False
if current_platform.is_rocm() or (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
): # handle hpu also for OOT platform
if (current_platform.is_rocm()
or (current_platform.is_cuda()
and current_platform.is_device_capability(100))
or current_platform.is_tpu()):
supported = True
elif fp8_attention and will_use_fa:
from vllm.attention.utils.fa_utils import (

View File

@ -35,7 +35,9 @@ class TpuPlatform(Platform):
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
simple_compile_backend: str = "openxla"
supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
supported_quantization: list[str] = [
"fp8", "tpu_int8", "compressed-tensors"
]
additional_env_vars: list[str] = [
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"

View File

@ -24,6 +24,19 @@ logger = init_logger(__name__)
# TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT = 128
# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8
# from to fp32 directly. That's why it has a dtype mapping different from GPU
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.float8_e4m3fn,
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
"int8": torch.int8,
"uint8": torch.uint8,
}
class PallasAttentionBackend(AttentionBackend):
@ -152,8 +165,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if kv_cache_dtype != "auto":
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
@ -161,6 +172,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
"are not implemented for "
"PallasAttentionBackendImpl")
self.kv_cache_quantized_dtype = None
if kv_cache_dtype != "auto":
self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
kv_cache_dtype.lower().strip())
def forward(
self,
layer: AttentionLayer,
@ -194,7 +210,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
output = torch.ones_like(query)
return output
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
@ -215,10 +230,21 @@ class PallasAttentionBackendImpl(AttentionImpl):
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(
key, value, kv_cache, slot_mapping,
key,
value,
kv_cache,
slot_mapping,
attn_metadata.num_slices_per_kv_cache_update_block,
attn_metadata.num_kv_update_slices)
attn_metadata.num_kv_update_slices,
self.kv_cache_quantized_dtype,
layer._k_scale_float,
layer._v_scale_float,
)
if self.kv_cache_quantized_dtype is not None and (
layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
raise ValueError(
"k_scale_float and v_scale_float must be non-zero")
output = torch.ops.xla.ragged_paged_attention(
query,
kv_cache,
@ -236,6 +262,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
sm_scale=self.scale,
sliding_window=self.sliding_window,
soft_cap=self.logits_soft_cap,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
@ -251,18 +279,32 @@ def write_to_kv_cache(
slot_mapping: torch.Tensor,
num_slices_per_kv_cache_update_block: int,
num_kv_update_slices: torch.Tensor,
kv_cache_quantized_dtype: Optional[torch.dtype] = None,
k_scale: float = 1.0,
v_scale: float = 1.0,
) -> None:
""" Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
num_slices_per_kv_cache_update_block: int
"""
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
head_size = cdiv(head_size,
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
if kv_cache_quantized_dtype is not None:
dtype_info = torch.finfo(kv_cache_quantized_dtype)
key = key.to(torch.float32) / k_scale
# NOTE: clamp is added here to avoid out of range of quantized dtype
key = torch.clamp(key, dtype_info.min, dtype_info.max)
key = key.to(kv_cache_quantized_dtype)
value = value.to(torch.float32) / v_scale
value = torch.clamp(value, dtype_info.min, dtype_info.max)
value = value.to(kv_cache_quantized_dtype)
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
head_size)

View File

@ -32,9 +32,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
is_pin_memory_available, prev_power_of_2)
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
prev_power_of_2)
from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
PallasAttentionBackend,
PallasMetadata,
get_page_size_bytes)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@ -142,11 +143,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
if cache_config.cache_dtype == "auto":
model_dtype = self.dtype
if isinstance(model_dtype, str):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
else:
self.kv_cache_dtype = model_dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self._hidden_states_dtype = self.dtype