mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 03:54:28 +08:00
[AMD] [Quantization] Add override flag for attention dtype instead of using kv_cache_dtype trigger (#17331)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
parent
29fa5cac1c
commit
c7ea0b56cd
@ -17,6 +17,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
CommonMetadataBuilder)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||
@ -584,6 +585,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
logger.debug("Using naive (SDPA) attention in ROCmBackend")
|
||||
|
||||
self.aiter_kv_scales_initialized = False
|
||||
self.force_fp8_attention = (
|
||||
get_current_vllm_config() is not None
|
||||
and get_current_vllm_config().model_config.override_attention_dtype
|
||||
== "fp8")
|
||||
|
||||
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||
@ -770,9 +775,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
query.dtype,
|
||||
seq_lens,
|
||||
make_attn_mask=causal_mask) # type: ignore
|
||||
|
||||
use_fp8_scales = (layer._q_scale and layer._k_scale
|
||||
and layer._v_scale and layer._prob_scale
|
||||
and self.kv_cache_dtype == "fp8")
|
||||
and (self.kv_cache_dtype == "fp8"
|
||||
or self.force_fp8_attention))
|
||||
|
||||
full_scales = (
|
||||
layer._q_scale.item(), layer._k_scale.item(),
|
||||
layer._v_scale.item(),
|
||||
|
||||
@ -417,6 +417,8 @@ class ModelConfig:
|
||||
available.\n
|
||||
- "vllm" will use the vLLM model implementation.\n
|
||||
- "transformers" will use the Transformers model implementation."""
|
||||
override_attention_dtype: Optional[str] = None
|
||||
"""Override dtype for attention"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
@ -517,6 +519,12 @@ class ModelConfig:
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if (self.override_attention_dtype is not None
|
||||
and not current_platform.is_rocm()):
|
||||
warnings.warn(
|
||||
"override-attention-dtype is set but not using ROCm platform",
|
||||
stacklevel=2)
|
||||
|
||||
if (self.enable_sleep_mode
|
||||
and not current_platform.is_sleep_mode_available()):
|
||||
raise ValueError(
|
||||
|
||||
@ -429,6 +429,7 @@ class EngineArgs:
|
||||
override_generation_config: dict[str, Any] = \
|
||||
get_field(ModelConfig, "override_generation_config")
|
||||
model_impl: str = ModelConfig.model_impl
|
||||
override_attention_dtype: str = ModelConfig.override_attention_dtype
|
||||
|
||||
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
||||
|
||||
@ -549,6 +550,8 @@ class EngineArgs:
|
||||
model_group.add_argument("--model-impl",
|
||||
choices=[f.value for f in ModelImpl],
|
||||
**model_kwargs["model_impl"])
|
||||
model_group.add_argument("--override-attention-dtype",
|
||||
**model_kwargs["override_attention_dtype"])
|
||||
|
||||
# Model loading arguments
|
||||
load_kwargs = get_kwargs(LoadConfig)
|
||||
@ -946,6 +949,7 @@ class EngineArgs:
|
||||
override_generation_config=self.override_generation_config,
|
||||
enable_sleep_mode=self.enable_sleep_mode,
|
||||
model_impl=self.model_impl,
|
||||
override_attention_dtype=self.override_attention_dtype,
|
||||
)
|
||||
|
||||
def create_load_config(self) -> LoadConfig:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user