[AMD] [Quantization] Add override flag for attention dtype instead of using kv_cache_dtype trigger (#17331)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Authored by rasmith on 2025-06-11 14:53:28 -05:00; committed by GitHub
parent 29fa5cac1c
commit c7ea0b56cd
3 changed files with 21 additions and 1 deletion


@@ -17,6 +17,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
                                             CommonMetadataBuilder)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
+from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
@@ -584,6 +585,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 logger.debug("Using naive (SDPA) attention in ROCmBackend")

         self.aiter_kv_scales_initialized = False
+        self.force_fp8_attention = (
+            get_current_vllm_config() is not None
+            and get_current_vllm_config().model_config.override_attention_dtype
+            == "fp8")

     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -770,9 +775,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         query.dtype,
                         seq_lens,
                         make_attn_mask=causal_mask)  # type: ignore

                 use_fp8_scales = (layer._q_scale and layer._k_scale
                                   and layer._v_scale and layer._prob_scale
-                                  and self.kv_cache_dtype == "fp8")
+                                  and (self.kv_cache_dtype == "fp8"
+                                       or self.force_fp8_attention))
                 full_scales = (
                     layer._q_scale.item(), layer._k_scale.item(),
                     layer._v_scale.item(),
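Net effect in the ROCm backend: the Triton attention path now uses the FP8 per-tensor scales when all four scales are available and either the KV cache dtype is already "fp8" or the new override explicitly forces FP8 attention. Below is a minimal standalone sketch of that gating, with the scales, kv_cache_dtype, and override passed as plain arguments instead of being read from the layer and the current vLLM config; the helper name is ours, not part of vLLM.

from typing import Optional

def should_use_fp8_scales(q_scale: Optional[float],
                          k_scale: Optional[float],
                          v_scale: Optional[float],
                          prob_scale: Optional[float],
                          kv_cache_dtype: str,
                          override_attention_dtype: Optional[str]) -> bool:
    # Mirrors the condition above: all four scales must be set, and FP8 is
    # triggered either by the KV-cache dtype or by the override flag.
    force_fp8_attention = override_attention_dtype == "fp8"
    return bool(q_scale and k_scale and v_scale and prob_scale
                and (kv_cache_dtype == "fp8" or force_fp8_attention))

# FP8 attention despite a non-FP8 ("auto") KV cache:
assert should_use_fp8_scales(1.0, 1.0, 1.0, 1.0, "auto", "fp8")
# Without the override, a non-FP8 KV cache keeps the old behavior:
assert not should_use_fp8_scales(1.0, 1.0, 1.0, 1.0, "auto", None)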


@@ -417,6 +417,8 @@ class ModelConfig:
     available.\n
     - "vllm" will use the vLLM model implementation.\n
     - "transformers" will use the Transformers model implementation."""
+    override_attention_dtype: Optional[str] = None
+    """Override dtype for attention"""

     def compute_hash(self) -> str:
         """
@@ -517,6 +519,12 @@ class ModelConfig:
         from vllm.platforms import current_platform

+        if (self.override_attention_dtype is not None
+                and not current_platform.is_rocm()):
+            warnings.warn(
+                "override-attention-dtype is set but not using ROCm platform",
+                stacklevel=2)
+
         if (self.enable_sleep_mode
                 and not current_platform.is_sleep_mode_available()):
             raise ValueError(
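The new ModelConfig.override_attention_dtype field defaults to None and only changes behavior on ROCm; setting it elsewhere is accepted but produces a warning, as the check above shows. Here is a small sketch of that validation pattern in isolation (a plain function, not the actual vLLM config code).

import warnings
from typing import Optional

def warn_if_override_unsupported(override_attention_dtype: Optional[str],
                                 is_rocm: bool) -> None:
    # Same message and stacklevel as the ModelConfig check above.
    if override_attention_dtype is not None and not is_rocm:
        warnings.warn(
            "override-attention-dtype is set but not using ROCm platform",
            stacklevel=2)

warn_if_override_unsupported("fp8", is_rocm=False)  # emits a UserWarning
warn_if_override_unsupported("fp8", is_rocm=True)   # silent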


@@ -429,6 +429,7 @@ class EngineArgs:
     override_generation_config: dict[str, Any] = \
         get_field(ModelConfig, "override_generation_config")
     model_impl: str = ModelConfig.model_impl
+    override_attention_dtype: str = ModelConfig.override_attention_dtype

     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
@@ -549,6 +550,8 @@
         model_group.add_argument("--model-impl",
                                  choices=[f.value for f in ModelImpl],
                                  **model_kwargs["model_impl"])
+        model_group.add_argument("--override-attention-dtype",
+                                 **model_kwargs["override_attention_dtype"])

         # Model loading arguments
         load_kwargs = get_kwargs(LoadConfig)
@@ -946,6 +949,7 @@
             override_generation_config=self.override_generation_config,
             enable_sleep_mode=self.enable_sleep_mode,
             model_impl=self.model_impl,
+            override_attention_dtype=self.override_attention_dtype,
         )

     def create_load_config(self) -> LoadConfig:
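With the EngineArgs plumbing above, the override is exposed as --override-attention-dtype on the command line (e.g. for vllm serve), and since the offline LLM entry point forwards extra keyword arguments to EngineArgs, it can be set from Python as well. A hedged usage sketch, assuming a ROCm build that contains this change; the model name is purely illustrative.

from vllm import LLM, SamplingParams

# Force FP8 attention on ROCm without also switching the KV cache to fp8,
# which previously was the only trigger for the FP8 attention path.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative; any supported model
    override_attention_dtype="fp8",
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)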