diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 6a65bbbe2e0d..7c7712dbe106 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -8,7 +8,6 @@
 import torch.distributed as dist
 from torch import nn
 from transformers import GptOssConfig
-from vllm import envs
 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
@@ -70,11 +69,9 @@ class OAIAttention(nn.Module):
 
         tp_size = get_tensor_model_parallel_world_size()
 
-        attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
-                                else torch.bfloat16)
         self.sinks = torch.nn.Parameter(
             torch.empty(config.num_attention_heads // tp_size,
-                        dtype=attention_sink_dtype,
+                        dtype=torch.bfloat16,
                         requires_grad=False))
 
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 5998d4c3127f..6b23ed426806 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -154,6 +154,7 @@ def use_trtllm_attention(
     num_qo_heads: Optional[int],
     num_kv_heads: Optional[int],
     attn_head_size: Optional[int],
+    has_sinks: bool = False,
 ) -> bool:
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
     if not (current_platform.is_device_capability(100)
@@ -165,6 +166,13 @@ def use_trtllm_attention(
             or num_qo_heads % num_kv_heads != 0):
         return False
 
+    # If sinks are being used, we must use TRTLLM attention as it's
+    # the only backend that supports them
+    if has_sinks:
+        logger.info_once(
+            "Using TRTLLM attention (required for attention sinks).")
+        return True
+
     env_value = envs.VLLM_USE_TRTLLM_ATTENTION
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index c85d8bce31f5..12e5542d691c 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -523,14 +523,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         num_kv_heads = self.kv_cache_spec.num_kv_heads
         head_dim = self.kv_cache_spec.head_size
 
+        # Check if any layer uses sinks (requires TRTLLM attention)
+        has_sinks = self.global_hyperparameters.has_sinks
+
         # currently prefill trtllm attention does not support fp8 kv cache
         prefill_use_trtllm = not cache_dtype.startswith("fp8") \
             and use_trtllm_attention(
             num_prefill_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim)
+            num_qo_heads, num_kv_heads, head_dim, has_sinks)
         decode_use_trtllm = use_trtllm_attention(
             num_decode_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim)
+            num_qo_heads, num_kv_heads, head_dim, has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -642,9 +645,9 @@ class FlashInferImpl(AttentionImpl):
                     f"heads in the layer. Expected {num_heads}, but got "
                     f"{sinks.shape[0]}."
                 )
+            # Cast sinks to float32 if needed (FlashInfer requirement)
             if sinks.dtype != torch.float32:
-                raise ValueError("Sinks must be of type float32, but got "
-                                 f"{sinks.dtype}.")
+                sinks = sinks.to(torch.float32)
             self.sinks = sinks
 
     def forward(
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index e23dd8bc5bbb..91eb84245ac0 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -285,6 +285,7 @@ class PerLayerParameters:
     window_left: int
     logits_soft_cap: Optional[float]
    sm_scale: float
+    has_sinks: bool = False
 
 
 def get_per_layer_parameters(
@@ -307,9 +308,11 @@ def get_per_layer_parameters(
         window_left = window_size[0] if window_size is not None else -1
         logits_soft_cap = getattr(impl, "logits_soft_cap", None)
         sm_scale = impl.scale
+        has_sinks = getattr(impl, "sinks", None) is not None
 
         per_layer_params[key] = PerLayerParameters(window_left,
-                                                   logits_soft_cap, sm_scale)
+                                                   logits_soft_cap, sm_scale,
+                                                   has_sinks)
 
     return per_layer_params
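For reviewers, here is a minimal, self-contained sketch of the control flow this change introduces: a layer is flagged as having sinks when its attention impl carries a non-None `sinks` attribute, that flag forces the TRTLLM path in `use_trtllm_attention`, and the model's bfloat16 sink parameters are cast to float32 at the FlashInfer boundary. The helper names below (`use_trtllm_attention_sketch`, `FakeImpl`) are illustrative only, and the sketch deliberately omits the real function's SM100/artifactory and head-count checks.

```python
from typing import Optional

import torch


# Simplified stand-in for vllm.utils.flashinfer.use_trtllm_attention: only the
# new sinks short-circuit and the env override are modeled here.
def use_trtllm_attention_sketch(has_sinks: bool,
                                env_value: Optional[bool]) -> bool:
    if has_sinks:
        # Sinks are only supported by the TRTLLM kernels, so they take
        # priority over the environment override and remaining heuristics.
        return True
    if env_value is not None:
        return env_value
    return False  # remaining heuristics elided


# Per-layer detection, mirroring get_per_layer_parameters(): a layer "has
# sinks" when its attention impl exposes a non-None `sinks` attribute.
class FakeImpl:  # hypothetical stand-in for an attention impl
    def __init__(self, sinks: Optional[torch.Tensor]):
        self.sinks = sinks


impl = FakeImpl(sinks=torch.empty(8, dtype=torch.bfloat16))
has_sinks = getattr(impl, "sinks", None) is not None
assert use_trtllm_attention_sketch(has_sinks, env_value=None) is True

# FlashInferImpl now casts the bfloat16 sinks to float32 instead of raising.
sinks = impl.sinks
if sinks.dtype != torch.float32:
    sinks = sinks.to(torch.float32)
assert sinks.dtype == torch.float32
```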