diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 4b460dc0b58cd..7887ebf65f44e 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -17,6 +17,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
+from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
@@ -584,6 +585,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
         self.aiter_kv_scales_initialized = False
+        self.force_fp8_attention = (
+            get_current_vllm_config() is not None
+            and get_current_vllm_config().model_config.override_attention_dtype
+            == "fp8")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -770,9 +775,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         query.dtype,
                         seq_lens,
                         make_attn_mask=causal_mask)  # type: ignore
+
                 use_fp8_scales = (layer._q_scale and layer._k_scale
                                   and layer._v_scale and layer._prob_scale
-                                  and self.kv_cache_dtype == "fp8")
+                                  and (self.kv_cache_dtype == "fp8"
+                                       or self.force_fp8_attention))
+
                 full_scales = (
                     layer._q_scale.item(), layer._k_scale.item(),
                     layer._v_scale.item(),
diff --git a/vllm/config.py b/vllm/config.py
index 32ef83a1866db..5da44988bc5f1 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -417,6 +417,8 @@ class ModelConfig:
     available.\n
     - "vllm" will use the vLLM model implementation.\n
     - "transformers" will use the Transformers model implementation."""
+    override_attention_dtype: Optional[str] = None
+    """Override dtype for attention"""
 
     def compute_hash(self) -> str:
         """
@@ -517,6 +519,12 @@ class ModelConfig:
 
         from vllm.platforms import current_platform
 
+        if (self.override_attention_dtype is not None
+                and not current_platform.is_rocm()):
+            warnings.warn(
+                "override-attention-dtype is set but not using ROCm platform",
+                stacklevel=2)
+
         if (self.enable_sleep_mode
                 and not current_platform.is_sleep_mode_available()):
             raise ValueError(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 38d567acfd8af..85b7bbfbd93d1 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -429,6 +429,7 @@ class EngineArgs:
     override_generation_config: dict[str, Any] = \
         get_field(ModelConfig, "override_generation_config")
     model_impl: str = ModelConfig.model_impl
+    override_attention_dtype: str = ModelConfig.override_attention_dtype
 
     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
 
@@ -549,6 +550,8 @@ class EngineArgs:
         model_group.add_argument("--model-impl",
                                  choices=[f.value for f in ModelImpl],
                                  **model_kwargs["model_impl"])
+        model_group.add_argument("--override-attention-dtype",
+                                 **model_kwargs["override_attention_dtype"])
 
         # Model loading arguments
         load_kwargs = get_kwargs(LoadConfig)
@@ -946,6 +949,7 @@ class EngineArgs:
             override_generation_config=self.override_generation_config,
             enable_sleep_mode=self.enable_sleep_mode,
             model_impl=self.model_impl,
+            override_attention_dtype=self.override_attention_dtype,
         )
 
     def create_load_config(self) -> LoadConfig:
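
For reference, a minimal usage sketch of the option this patch adds, not part of the diff itself: it assumes a ROCm build of vLLM that includes this change, and that the new EngineArgs field is forwarded through the LLM constructor as the diff wires it up. The model ID is a placeholder.

    # Hypothetical usage sketch for override_attention_dtype; the model ID
    # below is a placeholder, not something this patch prescribes.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        override_attention_dtype="fp8",  # force the fp8 attention path on ROCm
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)

    # Assumed equivalent server invocation via the new CLI flag:
    #   vllm serve meta-llama/Llama-3.1-8B-Instruct --override-attention-dtype fp8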