Force TRTLLM attention for gpt-oss on SM100 (#22678)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-08-13 00:22:16 -04:00 committed by GitHub
parent b1361c7273
commit c6b928798e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 20 additions and 9 deletions

View File

@ -8,7 +8,6 @@ import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig
from vllm import envs
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@ -70,11 +69,9 @@ class OAIAttention(nn.Module):
tp_size = get_tensor_model_parallel_world_size()
attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
else torch.bfloat16)
self.sinks = torch.nn.Parameter(
torch.empty(config.num_attention_heads // tp_size,
dtype=attention_sink_dtype,
dtype=torch.bfloat16,
requires_grad=False))
self.norm = RMSNorm(config.hidden_size, eps=1e-5)

View File

@ -154,6 +154,7 @@ def use_trtllm_attention(
num_qo_heads: Optional[int],
num_kv_heads: Optional[int],
attn_head_size: Optional[int],
has_sinks: bool = False,
) -> bool:
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
if not (current_platform.is_device_capability(100)
@ -165,6 +166,13 @@ def use_trtllm_attention(
or num_qo_heads % num_kv_heads != 0):
return False
# If sinks are being used, we must use TRTLLM attention as it's
# the only backend that supports them
if has_sinks:
logger.info_once(
"Using TRTLLM attention (required for attention sinks).")
return True
env_value = envs.VLLM_USE_TRTLLM_ATTENTION
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)

View File

@ -523,14 +523,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_kv_heads = self.kv_cache_spec.num_kv_heads
head_dim = self.kv_cache_spec.head_size
# Check if any layer uses sinks (requires TRTLLM attention)
has_sinks = self.global_hyperparameters.has_sinks
# currently prefill trtllm attention does not support fp8 kv cache
prefill_use_trtllm = not cache_dtype.startswith("fp8") \
and use_trtllm_attention(
num_prefill_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim)
num_qo_heads, num_kv_heads, head_dim, has_sinks)
decode_use_trtllm = use_trtllm_attention(
num_decode_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim)
num_qo_heads, num_kv_heads, head_dim, has_sinks)
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
@ -642,9 +645,9 @@ class FlashInferImpl(AttentionImpl):
f"heads in the layer. Expected {num_heads}, but got "
f"{sinks.shape[0]}."
)
# Cast sinks to float32 if needed (FlashInfer requirement)
if sinks.dtype != torch.float32:
raise ValueError("Sinks must be of type float32, but got "
f"{sinks.dtype}.")
sinks = sinks.to(torch.float32)
self.sinks = sinks
def forward(

View File

@ -285,6 +285,7 @@ class PerLayerParameters:
window_left: int
logits_soft_cap: Optional[float]
sm_scale: float
has_sinks: bool = False
def get_per_layer_parameters(
@ -307,9 +308,11 @@ def get_per_layer_parameters(
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = getattr(impl, "logits_soft_cap", None)
sm_scale = impl.scale
has_sinks = getattr(impl, "sinks", None) is not None
per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)
logits_soft_cap, sm_scale,
has_sinks)
return per_layer_params