fp8 kv cache support fix for torch.compile (#22758)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Author: Aleksandr Malyshev
Date: 2025-09-16 14:27:11 -07:00 (committed by GitHub)
Parent: 02d4b85454
Commit: 3053a22b33
2 changed files with 4 additions and 2 deletions


@@ -125,7 +125,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
-        layer._q_scale_float = q_scale
+        layer._q_scale_float = q_scale.item() if isinstance(
+            q_scale, torch.Tensor) else q_scale
+
         layer._prob_scale.copy_(prob_scale)
         if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
                                               or prob_scale == 1.0):
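Note: the hunk above caches _q_scale_float as a plain Python float. The sketch below is not part of the diff; normalize_scale is a hypothetical helper that mirrors the same pattern, calling .item() only when the scale arrives as a tensor so downstream code always sees a host-side float.

    import torch

    def normalize_scale(q_scale):
        # Checkpoints may provide the scale as a 0-dim tensor or as a plain
        # float; .item() converts the tensor case to a Python float so later
        # comparisons do not have to read device memory.
        return q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale

    print(normalize_scale(torch.tensor(0.5)))  # 0.5
    print(normalize_scale(0.5))                # 0.5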


@@ -361,7 +361,7 @@ class TritonAttentionImpl(AttentionImpl):
         key_cache = key_cache.view(self.fp8_dtype)
         value_cache = value_cache.view(self.fp8_dtype)
         num_tokens, num_heads, head_size = query.shape
-        assert layer._q_scale == 1.0, \
+        assert layer._q_scale_float == 1.0, \
             "A non 1.0 q_scale is not currently supported."
         if current_platform.is_cuda():
             # Skip Q quantization on ROCm and XPU, enable this on cuda
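For context on why the assert now reads _q_scale_float: a Python float can be checked when the graph is traced, whereas layer._q_scale is a tensor, and evaluating tensor == 1.0 inside a compiled region requires a host-side bool() conversion. The sketch below uses a toy scaled_forward function (an assumption for illustration, not vLLM's actual attention path).

    import torch

    def scaled_forward(query: torch.Tensor, q_scale_float: float) -> torch.Tensor:
        # The float comparison is resolved at trace time; a tensor comparison
        # here would have to be materialized at runtime.
        assert q_scale_float == 1.0, \
            "A non 1.0 q_scale is not currently supported."
        return query * q_scale_float

    compiled = torch.compile(scaled_forward, backend="eager")
    out = compiled(torch.randn(2, 4, 64), 1.0)
    print(out.shape)  # torch.Size([2, 4, 64])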