diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py
index dc039a0259aa5..217db3bf965de 100644
--- a/vllm/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -9,6 +9,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
 from vllm.triton_utils import tl, triton
 
@@ -267,7 +268,7 @@ def chunked_prefill_paged_decode(
         assert value_cache.dtype == torch.uint8
 
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
-            target_dtype = torch.float8_e4m3fn
+            target_dtype = current_platform.fp8_dtype()
         elif kv_cache_dtype == "fp8_e5m2":
             target_dtype = torch.float8_e5m2
         else:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index a3b34f4ba6544..26a5784e8458c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1205,7 +1205,9 @@ class EngineArgs:
                 and not envs.is_set("VLLM_ATTENTION_BACKEND")
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
-            if fp8_attention and will_use_fa:
+            if current_platform.is_rocm():
+                supported = True
+            elif fp8_attention and will_use_fa:
                 from vllm.attention.utils.fa_utils import (
                     flash_attn_supports_fp8)
                 supported = flash_attn_supports_fp8()
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index bb700c8e2e7ad..c4922a716bc2a 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     FlashAttentionMetadata, FlashAttentionMetadataBuilder)
 
@@ -108,6 +109,8 @@ class TritonAttentionImpl(AttentionImpl):
                                       "are not implemented for "
                                       "TritonAttentionImpl")
 
+        self.fp8_dtype = current_platform.fp8_dtype()
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -161,15 +164,18 @@ class TritonAttentionImpl(AttentionImpl):
             )
 
         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(torch.float8_e4m3fn)
-            value_cache = value_cache.view(torch.float8_e4m3fn)
+            key_cache = key_cache.view(self.fp8_dtype)
+            value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
+            if not current_platform.is_rocm():
+                # Skip Q quantization on ROCm, since dequantizing back to
+                # f32 in the attention kernel is not supported.
+                query, _ = ops.scaled_fp8_quant(
+                    query.reshape(
+                        (num_tokens, num_heads * head_size)).contiguous(),
+                    layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))
 
         use_local_attn = \
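
The reason the hard-coded torch.float8_e4m3fn is replaced with a current_platform.fp8_dtype() lookup: CUDA GPUs use the OCP float8_e4m3fn encoding, while ROCm MI300-class (gfx94x) GPUs use the float8_e4m3fnuz variant, so viewing the uint8 KV cache with the wrong dtype would misinterpret the stored bits. The sketch below is illustrative only, not vLLM's implementation; the helper name pick_fp8_dtype and its arguments are hypothetical.

# Minimal sketch (assumed helper, not vLLM's API) of the platform-dependent
# FP8 dtype choice that current_platform.fp8_dtype() encapsulates.
import torch


def pick_fp8_dtype(is_rocm: bool, is_fnuz_hw: bool) -> torch.dtype:
    """Select the FP8 dtype used to reinterpret the uint8 KV cache."""
    if is_rocm and is_fnuz_hw:
        return torch.float8_e4m3fnuz  # ROCm gfx94x (MI300-class) FP8 format
    return torch.float8_e4m3fn        # OCP FP8 format (CUDA and other platforms)


# Example: view a uint8 KV-cache tensor as the selected FP8 dtype, mirroring
# key_cache.view(self.fp8_dtype) in the diff above. A dtype view is allowed
# here because uint8 and the FP8 dtypes all have a 1-byte element size.
kv_cache = torch.zeros(4, 8, dtype=torch.uint8)
fp8_view = kv_cache.view(pick_fp8_dtype(is_rocm=False, is_fnuz_hw=False))
print(fp8_view.dtype)  # torch.float8_e4m3fn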