diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py
index 30a666d4c39c..6c5706d16340 100644
--- a/tests/entrypoints/llm/test_accuracy.py
+++ b/tests/entrypoints/llm/test_accuracy.py
@@ -15,15 +15,18 @@ import pytest
 from vllm.platforms import current_platform
 
 MODEL_NAMES = [
-    "Qwen/Qwen2-1.5B-Instruct",
+    "Qwen/Qwen3-1.7B",
     "google/gemma-3-1b-it",
 ]
+FP8_KV_MODEL_NAMES = [
+    "Qwen/Qwen3-1.7B",
+]
 NUM_CONCURRENT = 500
 TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
 EXPECTED_VALUES = {
-    "Qwen/Qwen2-1.5B-Instruct": 0.58,
+    "Qwen/Qwen3-1.7B": 0.68,
     "google/gemma-3-1b-it": 0.25,
 }
 
@@ -70,10 +73,9 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
 
         if current_platform.is_tpu():
             # Limit compilation time for TPU V1
-            if model == "google/gemma-3-1b-it":
-                # TPU + google/gemma-3-1b-it + xet doesn't work well.
-                m.setenv("HF_HUB_DISABLE_XET", "1")
-
+            # xet doesn't work well for both Qwen/Qwen3-1.7B and
+            # google/gemma-3-1b-it
+            m.setenv("HF_HUB_DISABLE_XET", "1")
             more_args = "max_model_len=2048,max_num_seqs=64"
 
         # Add TP test (if provided)
@@ -83,9 +85,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
         run_test(model, more_args)
 
 
-def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
-    """Run with the V0 Engine."""
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
+                    reason="V1 is currently only supported on CUDA and TPU")
+@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
+def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
+        model, monkeypatch: pytest.MonkeyPatch):
+    """Run with the V1 Engine."""
 
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "0")
-        run_test("Qwen/Qwen2-1.5B-Instruct")
+        m.setenv("VLLM_USE_V1", "1")
+
+        more_args = None
+        if current_platform.is_tpu():
+            # Limit compilation time for TPU V1
+
+            # xet doesn't work well for Qwen/Qwen3-1.7B
+            m.setenv("HF_HUB_DISABLE_XET", "1")
+            more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
+
+        # Add TP test (if provided)
+        if TPU_TP_TEST_STR:
+            more_args += ",{}".format(TPU_TP_TEST_STR)
+
+        run_test(model, more_args)
diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py
index df89133170b8..bfba3af57f71 100644
--- a/tests/v1/tpu/test_pallas.py
+++ b/tests/v1/tpu/test_pallas.py
@@ -95,4 +95,6 @@ def test_ragged_paged_attention():
         sm_scale=scale,
         sliding_window=sliding_window,
         soft_cap=logits_soft_cap,
+        k_scale=1.0,
+        v_scale=1.0,
     )
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1ca4917de26b..019ff033eda2 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1358,10 +1358,10 @@ class EngineArgs:
             and not envs.is_set("VLLM_ATTENTION_BACKEND")
         ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
         supported = False
-        if current_platform.is_rocm() or (
-            current_platform.is_cuda()
-            and current_platform.is_device_capability(100)
-        ):  # handle hpu also for OOT platform
+        if (current_platform.is_rocm()
+                or (current_platform.is_cuda()
+                    and current_platform.is_device_capability(100))
+                or current_platform.is_tpu()):
             supported = True
         elif fp8_attention and will_use_fa:
             from vllm.attention.utils.fa_utils import (
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 5ec3be908e7d..febc6ae4662b 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -35,7 +35,9 @@ class TpuPlatform(Platform):
     device_control_env_var: str = "TPU_VISIBLE_CHIPS"
     simple_compile_backend: str = "openxla"
 
-    supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
+    supported_quantization: list[str] = [
+        "fp8", "tpu_int8", "compressed-tensors"
+    ]
 
     additional_env_vars: list[str] = [
         "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index ac7980c79e4d..9307cd937d5d 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -24,6 +24,19 @@ logger = init_logger(__name__)
 # TPU requires the head size to be a multiple of 128.
 TPU_HEAD_SIZE_ALIGNMENT = 128
 
+# Note: TPU can use fp8 as the storage dtype but doesn't support converting
+# from uint8 to fp32 directly, hence a dtype mapping different from the GPU's.
+TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8": torch.float8_e4m3fn,
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
+    "int8": torch.int8,
+    "uint8": torch.uint8,
+}
+
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -152,8 +165,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
-        if kv_cache_dtype != "auto":
-            raise NotImplementedError("FP8 KV cache dtype is not supported.")
 
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
@@ -161,6 +172,11 @@
                                       "are not implemented for "
                                       "PallasAttentionBackendImpl")
 
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
+                kv_cache_dtype.lower().strip())
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -194,7 +210,6 @@
             output = torch.ones_like(query)
             return output
 
-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -215,10 +230,21 @@
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(
-                key, value, kv_cache, slot_mapping,
+                key,
+                value,
+                kv_cache,
+                slot_mapping,
                 attn_metadata.num_slices_per_kv_cache_update_block,
-                attn_metadata.num_kv_update_slices)
+                attn_metadata.num_kv_update_slices,
+                self.kv_cache_quantized_dtype,
+                layer._k_scale_float,
+                layer._v_scale_float,
+            )
+        if self.kv_cache_quantized_dtype is not None and (
+                layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
+            raise ValueError(
+                "k_scale_float and v_scale_float must be non-zero")
 
         output = torch.ops.xla.ragged_paged_attention(
             query,
             kv_cache,
@@ -236,6 +262,8 @@
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
             soft_cap=self.logits_soft_cap,
+            k_scale=layer._k_scale_float,
+            v_scale=layer._v_scale_float,
         )
 
         if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
@@ -251,18 +279,32 @@ def write_to_kv_cache(
     slot_mapping: torch.Tensor,
     num_slices_per_kv_cache_update_block: int,
    num_kv_update_slices: torch.Tensor,
+    kv_cache_quantized_dtype: Optional[torch.dtype] = None,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ) -> None:
     """ Write the key and values to the KV cache.
 
     Args:
-        key: shape = [num_tokens, num_kv_heads * head_size]
-        value: shape = [num_tokens, num_kv_heads * head_size]
+        key: shape = [num_tokens, num_kv_heads, head_size]
+        value: shape = [num_tokens, num_kv_heads, head_size]
         kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
         num_slices_per_kv_cache_update_block: int
     """
     _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
     head_size = cdiv(head_size,
                      TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+
+    if kv_cache_quantized_dtype is not None:
+        dtype_info = torch.finfo(kv_cache_quantized_dtype)
+        key = key.to(torch.float32) / k_scale
+        # NOTE: clamp keeps values within the range of the quantized dtype
+        key = torch.clamp(key, dtype_info.min, dtype_info.max)
+        key = key.to(kv_cache_quantized_dtype)
+        value = value.to(torch.float32) / v_scale
+        value = torch.clamp(value, dtype_info.min, dtype_info.max)
+        value = value.to(kv_cache_quantized_dtype)
+
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                   head_size)
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 1b55e5d61aa9..7ed1cf41011b 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -32,9 +32,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
-                        is_pin_memory_available, prev_power_of_2)
-from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
+from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
+                        prev_power_of_2)
+from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
+                                               PallasAttentionBackend,
                                                PallasMetadata,
                                                get_page_size_bytes)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -142,11 +143,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         if cache_config.cache_dtype == "auto":
             model_dtype = self.dtype
             if isinstance(model_dtype, str):
-                self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
             else:
                 self.kv_cache_dtype = model_dtype
         else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
 
         self._hidden_states_dtype = self.dtype
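
For reference, the fp8 KV-cache path in this patch reduces to per-tensor scaling: `write_to_kv_cache` divides K/V by the layer's `k_scale`/`v_scale`, clamps to the representable range of the storage dtype, and casts, while `ragged_paged_attention` receives the same scales to dequantize on read. The snippet below is an illustrative standalone sketch of that round trip, not code from the PR; the helper names `quantize_kv`/`dequantize_kv` are made up, and it assumes a recent PyTorch with float8 dtypes.

```python
import torch


def quantize_kv(x: torch.Tensor,
                scale: float,
                dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    """Scale, clamp to the dtype's finite range, then cast."""
    info = torch.finfo(dtype)
    x = x.to(torch.float32) / scale
    x = torch.clamp(x, info.min, info.max)  # avoid overflow in the narrow dtype
    return x.to(dtype)


def dequantize_kv(x: torch.Tensor,
                  scale: float,
                  dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Inverse of quantize_kv: cast up and multiply the scale back in."""
    return (x.to(torch.float32) * scale).to(dtype)


key = torch.randn(4, 8, 128, dtype=torch.bfloat16)  # [tokens, kv_heads, head]
k_scale = 1.0  # per-layer scale; 1.0 when the checkpoint provides none
key_fp8 = quantize_kv(key, k_scale)
key_ref = dequantize_kv(key_fp8, k_scale)
print(key_fp8.dtype, (key - key_ref).abs().max())
```

This also explains the new check in `PallasAttentionBackendImpl.forward`: a zero `k_scale`/`v_scale` would make the scaling step divide by zero, so the backend raises a `ValueError` instead.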
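End to end, the new accuracy test exercises this path by running the V1 engine with `kv_cache_dtype=fp8` (plus `max_model_len=2048`, `max_num_seqs=128` on TPU), and it can be selected with `pytest tests/entrypoints/llm/test_accuracy.py -k fp8_kv_cache`. Below is a hypothetical offline-inference sketch with the same engine arguments; the prompt and sampling settings are invented for illustration.

```python
from vllm import LLM, SamplingParams

# Mirrors the args used by test_lm_eval_accuracy_v1_engine_fp8_kv_cache on TPU;
# kv_cache_dtype="fp8" maps to torch.float8_e4m3fn in the Pallas backend.
llm = LLM(
    model="Qwen/Qwen3-1.7B",
    max_model_len=2048,
    max_num_seqs=128,
    kv_cache_dtype="fp8",
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```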