fp8 kv cache support fix for torch.compile (#22758)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
parent 02d4b85454
commit 3053a22b33
@@ -125,7 +125,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
-        layer._q_scale_float = q_scale
+        layer._q_scale_float = q_scale.item() if isinstance(
+            q_scale, torch.Tensor) else q_scale
+
         layer._prob_scale.copy_(prob_scale)
         if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
                                               or prob_scale == 1.0):
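The hunk above caches the q_scale as a plain Python float alongside the tensor copy. A minimal, hypothetical sketch of the same pattern (not vLLM code; the class and variable names here are made up):

import torch

# Minimal sketch (not vLLM code; names are made up). The pattern from the
# hunk above: next to the 0-d scale tensor, keep a plain Python float so
# later checks compare a constant instead of reading the tensor at runtime.
class DummyKVCacheLayer:
    def __init__(self, q_scale):
        # q_scale may arrive as a checkpoint tensor or as a plain float.
        self._q_scale = torch.as_tensor(q_scale, dtype=torch.float32)
        self._q_scale_float = q_scale.item() if isinstance(
            q_scale, torch.Tensor) else q_scale

layer = DummyKVCacheLayer(torch.tensor(1.0))
print(type(layer._q_scale_float))  # <class 'float'>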
@@ -361,7 +361,7 @@ class TritonAttentionImpl(AttentionImpl):
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
-            assert layer._q_scale == 1.0, \
+            assert layer._q_scale_float == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
             if current_platform.is_cuda():
                 # Skip Q quantization on ROCm and XPU, enable this on cuda
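The second hunk switches the assert in the attention forward from the tensor layer._q_scale to the cached float layer._q_scale_float. A minimal, hypothetical sketch (not vLLM code; assumes PyTorch 2.x) of why this matters once the forward runs under torch.compile: comparing the cached float only installs a guard, whereas asserting on a tensor comparison like `layer._q_scale == 1.0` calls bool() on a tensor inside the traced region, which would typically force a graph break or a device sync.

import torch

# Minimal sketch (not vLLM code; names are made up), assuming PyTorch 2.x.
class DummyLayer:
    def __init__(self):
        self._q_scale = torch.tensor(1.0)
        self._q_scale_float = self._q_scale.item()

@torch.compile
def forward_stub(query: torch.Tensor, layer: DummyLayer) -> torch.Tensor:
    # Asserting on the cached float is compile-friendly; asserting on
    # `layer._q_scale == 1.0` would bool() a tensor inside the traced region.
    assert layer._q_scale_float == 1.0, \
        "A non 1.0 q_scale is not currently supported."
    return query * 2.0

print(forward_stub(torch.ones(4), DummyLayer()))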