Force TRTLLM attention for gpt-oss on SM100 (#22678)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2025-08-13 00:22:16 -04:00 (committed by GitHub)
parent b1361c7273
commit c6b928798e
4 changed files with 20 additions and 9 deletions


@@ -8,7 +8,6 @@ import torch.distributed as dist
 from torch import nn
 from transformers import GptOssConfig
-from vllm import envs
 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
@@ -70,11 +69,9 @@ class OAIAttention(nn.Module):
         tp_size = get_tensor_model_parallel_world_size()
 
-        attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
-                                else torch.bfloat16)
         self.sinks = torch.nn.Parameter(
             torch.empty(config.num_attention_heads // tp_size,
-                        dtype=attention_sink_dtype,
+                        dtype=torch.bfloat16,
                         requires_grad=False))
 
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
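
For reference, a minimal standalone sketch of what the module does after this change (assumed head count and TP size; not the real OAIAttention class): the sink parameter is always allocated in bfloat16, and the dtype switch on VLLM_USE_TRTLLM_ATTENTION is gone because the backend now handles the float32 conversion.

    import torch

    num_attention_heads, tp_size = 64, 1  # illustrative values, not from the model config
    sinks = torch.nn.Parameter(
        torch.empty(num_attention_heads // tp_size,
                    dtype=torch.bfloat16,
                    requires_grad=False))
    assert sinks.dtype == torch.bfloat16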


@@ -154,6 +154,7 @@ def use_trtllm_attention(
     num_qo_heads: Optional[int],
     num_kv_heads: Optional[int],
     attn_head_size: Optional[int],
+    has_sinks: bool = False,
 ) -> bool:
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
     if not (current_platform.is_device_capability(100)
@@ -165,6 +166,13 @@ def use_trtllm_attention(
             or num_qo_heads % num_kv_heads != 0):
         return False
 
+    # If sinks are being used, we must use TRTLLM attention as it's
+    # the only backend that supports them
+    if has_sinks:
+        logger.info_once(
+            "Using TRTLLM attention (required for attention sinks).")
+        return True
+
     env_value = envs.VLLM_USE_TRTLLM_ATTENTION
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)


@@ -523,14 +523,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         num_kv_heads = self.kv_cache_spec.num_kv_heads
         head_dim = self.kv_cache_spec.head_size
 
+        # Check if any layer uses sinks (requires TRTLLM attention)
+        has_sinks = self.global_hyperparameters.has_sinks
+
         # currently prefill trtllm attention does not support fp8 kv cache
         prefill_use_trtllm = not cache_dtype.startswith("fp8") \
             and use_trtllm_attention(
                 num_prefill_tokens, max_seq_len, cache_dtype,
-                num_qo_heads, num_kv_heads, head_dim)
+                num_qo_heads, num_kv_heads, head_dim, has_sinks)
         decode_use_trtllm = use_trtllm_attention(
             num_decode_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim)
+            num_qo_heads, num_kv_heads, head_dim, has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
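
As a simplified, self-contained illustration of the planning logic in this hunk (the function below is a stand-in, not vLLM's builder): prefill additionally rejects fp8 KV caches, decode does not, and both paths now receive the has_sinks flag.

    def plan_trtllm(cache_dtype: str, has_sinks: bool,
                    heuristics_pass: bool) -> tuple[bool, bool]:
        # Stand-in for use_trtllm_attention: sinks force True, else heuristics.
        def use_trtllm() -> bool:
            return True if has_sinks else heuristics_pass

        prefill = (not cache_dtype.startswith("fp8")) and use_trtllm()
        decode = use_trtllm()
        return prefill, decode

    # With an fp8 KV cache, only the decode path can take TRTLLM attention.
    assert plan_trtllm("fp8_e4m3", has_sinks=True, heuristics_pass=False) == (False, True)
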
@@ -642,9 +645,9 @@ class FlashInferImpl(AttentionImpl):
                     f"heads in the layer. Expected {num_heads}, but got "
                     f"{sinks.shape[0]}."
                 )
+            # Cast sinks to float32 if needed (FlashInfer requirement)
             if sinks.dtype != torch.float32:
-                raise ValueError("Sinks must be of type float32, but got "
-                                 f"{sinks.dtype}.")
+                sinks = sinks.to(torch.float32)
             self.sinks = sinks
 
     def forward(
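
A short standalone sketch of the new acceptance behavior (assumed head count; not vLLM's class layout): a bfloat16 sink tensor coming from the model is now upcast to float32 instead of raising.

    import torch

    def accept_sinks(sinks: torch.Tensor, num_heads: int) -> torch.Tensor:
        if sinks.shape[0] != num_heads:
            raise ValueError(f"Expected {num_heads} sinks, got {sinks.shape[0]}.")
        # Previously a non-float32 dtype raised ValueError; now it is upcast.
        if sinks.dtype != torch.float32:
            sinks = sinks.to(torch.float32)
        return sinks

    bf16_sinks = torch.zeros(8, dtype=torch.bfloat16)  # assumed 8 heads
    assert accept_sinks(bf16_sinks, num_heads=8).dtype == torch.float32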


@@ -285,6 +285,7 @@ class PerLayerParameters:
     window_left: int
     logits_soft_cap: Optional[float]
     sm_scale: float
+    has_sinks: bool = False
 
 
 def get_per_layer_parameters(
@@ -307,9 +308,11 @@ def get_per_layer_parameters(
         window_left = window_size[0] if window_size is not None else -1
         logits_soft_cap = getattr(impl, "logits_soft_cap", None)
         sm_scale = impl.scale
+        has_sinks = getattr(impl, "sinks", None) is not None
 
-        per_layer_params[key] = PerLayerParameters(window_left,
-                                                   logits_soft_cap, sm_scale)
+        per_layer_params[key] = PerLayerParameters(window_left,
+                                                   logits_soft_cap, sm_scale,
+                                                   has_sinks)
 
     return per_layer_params
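
To illustrate the pattern in these two hunks with a self-contained example (the class names below are stand-ins for PerLayerParameters and the attention implementation, not vLLM types): the presence of a sinks attribute on a layer's impl is detected via getattr and recorded as a per-layer flag.

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class LayerParams:          # stand-in for PerLayerParameters
        window_left: int
        logits_soft_cap: Optional[float]
        sm_scale: float
        has_sinks: bool = False

    class DummyImpl:            # stand-in attention impl that carries sinks
        scale = 0.125
        sinks = torch.zeros(8, dtype=torch.float32)

    impl = DummyImpl()
    params = LayerParams(
        window_left=-1,
        logits_soft_cap=getattr(impl, "logits_soft_cap", None),
        sm_scale=impl.scale,
        has_sinks=getattr(impl, "sinks", None) is not None)
    assert params.has_sinks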