mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-10 20:10:19 +08:00
[AMD] [Quantization] Add override flag for attention dtype instead of using kv_cache_dtype trigger (#17331)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
parent
29fa5cac1c
commit
c7ea0b56cd
@ -17,6 +17,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
|
|||||||
CommonMetadataBuilder)
|
CommonMetadataBuilder)
|
||||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||||
PagedAttentionMetadata)
|
PagedAttentionMetadata)
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||||
@ -584,6 +585,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
logger.debug("Using naive (SDPA) attention in ROCmBackend")
|
logger.debug("Using naive (SDPA) attention in ROCmBackend")
|
||||||
|
|
||||||
self.aiter_kv_scales_initialized = False
|
self.aiter_kv_scales_initialized = False
|
||||||
|
self.force_fp8_attention = (
|
||||||
|
get_current_vllm_config() is not None
|
||||||
|
and get_current_vllm_config().model_config.override_attention_dtype
|
||||||
|
== "fp8")
|
||||||
|
|
||||||
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||||
@ -770,9 +775,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
query.dtype,
|
query.dtype,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
make_attn_mask=causal_mask) # type: ignore
|
make_attn_mask=causal_mask) # type: ignore
|
||||||
|
|
||||||
use_fp8_scales = (layer._q_scale and layer._k_scale
|
use_fp8_scales = (layer._q_scale and layer._k_scale
|
||||||
and layer._v_scale and layer._prob_scale
|
and layer._v_scale and layer._prob_scale
|
||||||
and self.kv_cache_dtype == "fp8")
|
and (self.kv_cache_dtype == "fp8"
|
||||||
|
or self.force_fp8_attention))
|
||||||
|
|
||||||
full_scales = (
|
full_scales = (
|
||||||
layer._q_scale.item(), layer._k_scale.item(),
|
layer._q_scale.item(), layer._k_scale.item(),
|
||||||
layer._v_scale.item(),
|
layer._v_scale.item(),
|
||||||
|
|||||||
@ -417,6 +417,8 @@ class ModelConfig:
|
|||||||
available.\n
|
available.\n
|
||||||
- "vllm" will use the vLLM model implementation.\n
|
- "vllm" will use the vLLM model implementation.\n
|
||||||
- "transformers" will use the Transformers model implementation."""
|
- "transformers" will use the Transformers model implementation."""
|
||||||
|
override_attention_dtype: Optional[str] = None
|
||||||
|
"""Override dtype for attention"""
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
@ -517,6 +519,12 @@ class ModelConfig:
|
|||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
if (self.override_attention_dtype is not None
|
||||||
|
and not current_platform.is_rocm()):
|
||||||
|
warnings.warn(
|
||||||
|
"override-attention-dtype is set but not using ROCm platform",
|
||||||
|
stacklevel=2)
|
||||||
|
|
||||||
if (self.enable_sleep_mode
|
if (self.enable_sleep_mode
|
||||||
and not current_platform.is_sleep_mode_available()):
|
and not current_platform.is_sleep_mode_available()):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@ -429,6 +429,7 @@ class EngineArgs:
|
|||||||
override_generation_config: dict[str, Any] = \
|
override_generation_config: dict[str, Any] = \
|
||||||
get_field(ModelConfig, "override_generation_config")
|
get_field(ModelConfig, "override_generation_config")
|
||||||
model_impl: str = ModelConfig.model_impl
|
model_impl: str = ModelConfig.model_impl
|
||||||
|
override_attention_dtype: str = ModelConfig.override_attention_dtype
|
||||||
|
|
||||||
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
||||||
|
|
||||||
@ -549,6 +550,8 @@ class EngineArgs:
|
|||||||
model_group.add_argument("--model-impl",
|
model_group.add_argument("--model-impl",
|
||||||
choices=[f.value for f in ModelImpl],
|
choices=[f.value for f in ModelImpl],
|
||||||
**model_kwargs["model_impl"])
|
**model_kwargs["model_impl"])
|
||||||
|
model_group.add_argument("--override-attention-dtype",
|
||||||
|
**model_kwargs["override_attention_dtype"])
|
||||||
|
|
||||||
# Model loading arguments
|
# Model loading arguments
|
||||||
load_kwargs = get_kwargs(LoadConfig)
|
load_kwargs = get_kwargs(LoadConfig)
|
||||||
@ -946,6 +949,7 @@ class EngineArgs:
|
|||||||
override_generation_config=self.override_generation_config,
|
override_generation_config=self.override_generation_config,
|
||||||
enable_sleep_mode=self.enable_sleep_mode,
|
enable_sleep_mode=self.enable_sleep_mode,
|
||||||
model_impl=self.model_impl,
|
model_impl=self.model_impl,
|
||||||
|
override_attention_dtype=self.override_attention_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_load_config(self) -> LoadConfig:
|
def create_load_config(self) -> LoadConfig:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user