From 3053a22b330cd7170dce6f33f3a2043c64a99599 Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Date: Tue, 16 Sep 2025 14:27:11 -0700
Subject: [PATCH] fp8 kv cache support fix for torch.compile (#22758)

Signed-off-by: Aleksandr Malyshev
Signed-off-by: Gregory Shtrasberg
Co-authored-by: Aleksandr Malyshev
Co-authored-by: Gregory Shtrasberg
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
---
 vllm/model_executor/layers/quantization/kv_cache.py | 4 +++-
 vllm/v1/attention/backends/triton_attn.py           | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index 4c6fcda893a03..275a1c43fdd2b 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -125,7 +125,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
 
         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
-        layer._q_scale_float = q_scale
+        layer._q_scale_float = q_scale.item() if isinstance(
+            q_scale, torch.Tensor) else q_scale
+
         layer._prob_scale.copy_(prob_scale)
         if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
                                               or prob_scale == 1.0):
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index c294a5a73cbdd..784912a122f68 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -361,7 +361,7 @@ class TritonAttentionImpl(AttentionImpl):
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
-            assert layer._q_scale == 1.0, \
+            assert layer._q_scale_float == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
             if current_platform.is_cuda():
                 # Skip Q quantization on ROCm and XPU, enable this on cuda
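
The substance of the change: layer._q_scale_float now holds a plain Python float even when the loaded q_scale is a tensor, and the assertion in TritonAttentionImpl checks that float instead of the device tensor, which is what the torch.compile fix in the subject line refers to. Below is a minimal, self-contained sketch of that pattern; DummyAttentionLayer and load_q_scale are hypothetical names used only for illustration, not vLLM APIs.

import torch


class DummyAttentionLayer:
    """Stand-in for an attention layer that carries KV-cache scales."""

    def __init__(self) -> None:
        # Device-side copy of the scale, consumed by the attention kernels.
        self._q_scale = torch.tensor(1.0)
        # Plain-float mirror used for cheap host-side checks.
        self._q_scale_float = 1.0


def load_q_scale(layer: DummyAttentionLayer, q_scale) -> None:
    """Store q_scale both as a tensor and as a Python float."""
    layer._q_scale.copy_(torch.as_tensor(q_scale, dtype=torch.float32))
    # .item() turns a 0-d tensor into a Python float; a plain float passes through.
    layer._q_scale_float = q_scale.item() if isinstance(
        q_scale, torch.Tensor) else q_scale


layer = DummyAttentionLayer()
load_q_scale(layer, torch.tensor(1.0))
# The forward path can now assert on a float; no tensor comparison is traced.
assert layer._q_scale_float == 1.0, \
    "A non 1.0 q_scale is not currently supported."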