[TPU] support fp8 kv cache quantization (#19292)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Chengji Yao 2025-07-19 20:01:00 -07:00 committed by GitHub
parent 2b504eb770
commit 3a1d8940ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 94 additions and 27 deletions

View File

@ -15,15 +15,18 @@ import pytest
from vllm.platforms import current_platform from vllm.platforms import current_platform
MODEL_NAMES = [ MODEL_NAMES = [
"Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen3-1.7B",
"google/gemma-3-1b-it", "google/gemma-3-1b-it",
] ]
FP8_KV_MODEL_NAMES = [
"Qwen/Qwen3-1.7B",
]
NUM_CONCURRENT = 500 NUM_CONCURRENT = 500
TASK = "gsm8k" TASK = "gsm8k"
FILTER = "exact_match,strict-match" FILTER = "exact_match,strict-match"
RTOL = 0.03 RTOL = 0.03
EXPECTED_VALUES = { EXPECTED_VALUES = {
"Qwen/Qwen2-1.5B-Instruct": 0.58, "Qwen/Qwen3-1.7B": 0.68,
"google/gemma-3-1b-it": 0.25, "google/gemma-3-1b-it": 0.25,
} }
@ -70,10 +73,9 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
if current_platform.is_tpu(): if current_platform.is_tpu():
# Limit compilation time for TPU V1 # Limit compilation time for TPU V1
if model == "google/gemma-3-1b-it": # xet doesn't work well for both Qwen/Qwen3-1.7B and
# TPU + google/gemma-3-1b-it + xet doesn't work well. # google/gemma-3-1b-it
m.setenv("HF_HUB_DISABLE_XET", "1") m.setenv("HF_HUB_DISABLE_XET", "1")
more_args = "max_model_len=2048,max_num_seqs=64" more_args = "max_model_len=2048,max_num_seqs=64"
# Add TP test (if provided) # Add TP test (if provided)
@ -83,9 +85,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
run_test(model, more_args) run_test(model, more_args)
def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch): @pytest.mark.skipif(not current_platform.is_cuda()
"""Run with the V0 Engine.""" and not current_platform.is_tpu(),
reason="V1 is currently only supported on CUDA and TPU")
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
model, monkeypatch: pytest.MonkeyPatch):
"""Run with the V1 Engine."""
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0") m.setenv("VLLM_USE_V1", "1")
run_test("Qwen/Qwen2-1.5B-Instruct")
more_args = None
if current_platform.is_tpu():
# Limit compilation time for TPU V1
# xet doesn't work well for Qwen/Qwen3-1.7B
m.setenv("HF_HUB_DISABLE_XET", "1")
more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
# Add TP test (if provided)
if TPU_TP_TEST_STR:
more_args += ",{}".format(TPU_TP_TEST_STR)
run_test(model, more_args)

View File

@ -95,4 +95,6 @@ def test_ragged_paged_attention():
sm_scale=scale, sm_scale=scale,
sliding_window=sliding_window, sliding_window=sliding_window,
soft_cap=logits_soft_cap, soft_cap=logits_soft_cap,
k_scale=1.0,
v_scale=1.0,
) )

View File

@ -1358,10 +1358,10 @@ class EngineArgs:
and not envs.is_set("VLLM_ATTENTION_BACKEND") and not envs.is_set("VLLM_ATTENTION_BACKEND")
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False supported = False
if current_platform.is_rocm() or ( if (current_platform.is_rocm()
current_platform.is_cuda() or (current_platform.is_cuda()
and current_platform.is_device_capability(100) and current_platform.is_device_capability(100))
): # handle hpu also for OOT platform or current_platform.is_tpu()):
supported = True supported = True
elif fp8_attention and will_use_fa: elif fp8_attention and will_use_fa:
from vllm.attention.utils.fa_utils import ( from vllm.attention.utils.fa_utils import (

View File

@ -35,7 +35,9 @@ class TpuPlatform(Platform):
device_control_env_var: str = "TPU_VISIBLE_CHIPS" device_control_env_var: str = "TPU_VISIBLE_CHIPS"
simple_compile_backend: str = "openxla" simple_compile_backend: str = "openxla"
supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"] supported_quantization: list[str] = [
"fp8", "tpu_int8", "compressed-tensors"
]
additional_env_vars: list[str] = [ additional_env_vars: list[str] = [
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"

View File

@ -24,6 +24,19 @@ logger = init_logger(__name__)
# TPU requires the head size to be a multiple of 128. # TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT = 128 TPU_HEAD_SIZE_ALIGNMENT = 128
# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8
# from to fp32 directly. That's why it has a dtype mapping different from GPU
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.float8_e4m3fn,
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
"int8": torch.int8,
"uint8": torch.uint8,
}
class PallasAttentionBackend(AttentionBackend): class PallasAttentionBackend(AttentionBackend):
@ -152,8 +165,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if alibi_slopes is not None: if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.") raise NotImplementedError("Alibi slopes is not supported.")
if kv_cache_dtype != "auto":
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
@ -161,6 +172,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
"are not implemented for " "are not implemented for "
"PallasAttentionBackendImpl") "PallasAttentionBackendImpl")
self.kv_cache_quantized_dtype = None
if kv_cache_dtype != "auto":
self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
kv_cache_dtype.lower().strip())
def forward( def forward(
self, self,
layer: AttentionLayer, layer: AttentionLayer,
@ -194,7 +210,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
output = torch.ones_like(query) output = torch.ones_like(query)
return output return output
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size) query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
@ -215,10 +230,21 @@ class PallasAttentionBackendImpl(AttentionImpl):
# Skip this if sharing KV cache with an earlier attention layer. # Skip this if sharing KV cache with an earlier attention layer.
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache( write_to_kv_cache(
key, value, kv_cache, slot_mapping, key,
value,
kv_cache,
slot_mapping,
attn_metadata.num_slices_per_kv_cache_update_block, attn_metadata.num_slices_per_kv_cache_update_block,
attn_metadata.num_kv_update_slices) attn_metadata.num_kv_update_slices,
self.kv_cache_quantized_dtype,
layer._k_scale_float,
layer._v_scale_float,
)
if self.kv_cache_quantized_dtype is not None and (
layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
raise ValueError(
"k_scale_float and v_scale_float must be non-zero")
output = torch.ops.xla.ragged_paged_attention( output = torch.ops.xla.ragged_paged_attention(
query, query,
kv_cache, kv_cache,
@ -236,6 +262,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
sm_scale=self.scale, sm_scale=self.scale,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
soft_cap=self.logits_soft_cap, soft_cap=self.logits_soft_cap,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
) )
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
@ -251,18 +279,32 @@ def write_to_kv_cache(
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
num_slices_per_kv_cache_update_block: int, num_slices_per_kv_cache_update_block: int,
num_kv_update_slices: torch.Tensor, num_kv_update_slices: torch.Tensor,
kv_cache_quantized_dtype: Optional[torch.dtype] = None,
k_scale: float = 1.0,
v_scale: float = 1.0,
) -> None: ) -> None:
""" Write the key and values to the KV cache. """ Write the key and values to the KV cache.
Args: Args:
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
num_slices_per_kv_cache_update_block: int num_slices_per_kv_cache_update_block: int
""" """
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
head_size = cdiv(head_size, head_size = cdiv(head_size,
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
if kv_cache_quantized_dtype is not None:
dtype_info = torch.finfo(kv_cache_quantized_dtype)
key = key.to(torch.float32) / k_scale
# NOTE: clamp is added here to avoid out of range of quantized dtype
key = torch.clamp(key, dtype_info.min, dtype_info.max)
key = key.to(kv_cache_quantized_dtype)
value = value.to(torch.float32) / v_scale
value = torch.clamp(value, dtype_info.min, dtype_info.max)
value = value.to(kv_cache_quantized_dtype)
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
head_size) head_size)

View File

@ -32,9 +32,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingTask from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv, from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
is_pin_memory_available, prev_power_of_2) prev_power_of_2)
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
PallasAttentionBackend,
PallasMetadata, PallasMetadata,
get_page_size_bytes) get_page_size_bytes)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@ -142,11 +143,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
if cache_config.cache_dtype == "auto": if cache_config.cache_dtype == "auto":
model_dtype = self.dtype model_dtype = self.dtype
if isinstance(model_dtype, str): if isinstance(model_dtype, str):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
else: else:
self.kv_cache_dtype = model_dtype self.kv_cache_dtype = model_dtype
else: else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype] cache_config.cache_dtype]
self._hidden_states_dtype = self.dtype self._hidden_states_dtype = self.dtype