Force TRTLLM attention for gpt-oss on SM100 (#22678)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent b1361c7273
commit c6b928798e
@@ -8,7 +8,6 @@ import torch.distributed as dist
 from torch import nn
 from transformers import GptOssConfig
 
-from vllm import envs
 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
@@ -70,11 +69,9 @@ class OAIAttention(nn.Module):
 
         tp_size = get_tensor_model_parallel_world_size()
 
-        attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
-                                else torch.bfloat16)
         self.sinks = torch.nn.Parameter(
             torch.empty(config.num_attention_heads // tp_size,
-                        dtype=attention_sink_dtype,
+                        dtype=torch.bfloat16,
                         requires_grad=False))
 
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
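With this change the model always allocates the per-rank sink parameter in bfloat16 and leaves any float32 conversion to the attention backend. A minimal standalone sketch of that allocation; num_attention_heads and tp_size are made-up values here, the real ones come from GptOssConfig and get_tensor_model_parallel_world_size():

import torch

# Hypothetical values for illustration only.
num_attention_heads = 64
tp_size = 4

# Sinks are now always created in bfloat16, one entry per local attention head.
sinks = torch.nn.Parameter(
    torch.empty(num_attention_heads // tp_size,
                dtype=torch.bfloat16,
                requires_grad=False))
print(sinks.shape, sinks.dtype)  # torch.Size([16]) torch.bfloat16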
@@ -154,6 +154,7 @@ def use_trtllm_attention(
     num_qo_heads: Optional[int],
     num_kv_heads: Optional[int],
     attn_head_size: Optional[int],
+    has_sinks: bool = False,
 ) -> bool:
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
     if not (current_platform.is_device_capability(100)
@@ -165,6 +166,13 @@ def use_trtllm_attention(
             or num_qo_heads % num_kv_heads != 0):
         return False
 
+    # If sinks are being used, we must use TRTLLM attention as it's
+    # the only backend that supports them
+    if has_sinks:
+        logger.info_once(
+            "Using TRTLLM attention (required for attention sinks).")
+        return True
+
     env_value = envs.VLLM_USE_TRTLLM_ATTENTION
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
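The key point of this hunk is ordering: sinks short-circuit the decision before the VLLM_USE_TRTLLM_ATTENTION override is even consulted. A hedged, self-contained sketch of that control flow; the helper name, the env_override parameter, and the stubbed-out checks are illustrative and not the real function:

from typing import Optional


def use_trtllm_attention_sketch(
    num_qo_heads: Optional[int],
    num_kv_heads: Optional[int],
    env_override: Optional[bool],
    has_sinks: bool = False,
) -> bool:
    """Simplified decision logic mirroring the order of checks in the diff."""
    # Basic sanity checks (the real code also verifies SM100 and cubin access).
    if (num_qo_heads is None or num_kv_heads is None
            or num_qo_heads % num_kv_heads != 0):
        return False

    # Sinks force the TRTLLM path: it is the only backend that supports them.
    if has_sinks:
        return True

    # Only afterwards is the explicit environment override honored.
    if env_override is not None:
        return env_override

    return False  # the real code falls through to further heuristics


# gpt-oss layers carry sinks, so TRTLLM is selected regardless of the env var.
assert use_trtllm_attention_sketch(64, 8, env_override=False, has_sinks=True)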
@@ -523,14 +523,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         num_kv_heads = self.kv_cache_spec.num_kv_heads
         head_dim = self.kv_cache_spec.head_size
 
+        # Check if any layer uses sinks (requires TRTLLM attention)
+        has_sinks = self.global_hyperparameters.has_sinks
+
         # currently prefill trtllm attention does not support fp8 kv cache
         prefill_use_trtllm = not cache_dtype.startswith("fp8") \
             and use_trtllm_attention(
                 num_prefill_tokens, max_seq_len, cache_dtype,
-                num_qo_heads, num_kv_heads, head_dim)
+                num_qo_heads, num_kv_heads, head_dim, has_sinks)
         decode_use_trtllm = use_trtllm_attention(
             num_decode_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim)
+            num_qo_heads, num_kv_heads, head_dim, has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
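Here has_sinks is read from the builder's global hyperparameters and threaded into both the prefill and decode decisions. How that flag might be derived from the per-layer parameters is sketched below as an assumption, not as the library's actual aggregation code; the LayerParamsSketch dataclass and any_layer_has_sinks helper are invented for illustration:

from dataclasses import dataclass
from typing import List


@dataclass
class LayerParamsSketch:
    """Stand-in for PerLayerParameters; only the field relevant here."""
    has_sinks: bool = False


def any_layer_has_sinks(per_layer: List[LayerParamsSketch]) -> bool:
    # Illustrative aggregation: if any attention layer carries sinks, the
    # model must take the TRTLLM path for both prefill and decode.
    return any(p.has_sinks for p in per_layer)


layers = [LayerParamsSketch(has_sinks=True), LayerParamsSketch()]
has_sinks = any_layer_has_sinks(layers)
# has_sinks is then forwarded to use_trtllm_attention(...) as in the hunk above.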
@@ -642,9 +645,9 @@ class FlashInferImpl(AttentionImpl):
                     f"heads in the layer. Expected {num_heads}, but got "
                     f"{sinks.shape[0]}."
                 )
+            # Cast sinks to float32 if needed (FlashInfer requirement)
             if sinks.dtype != torch.float32:
-                raise ValueError("Sinks must be of type float32, but got "
-                                 f"{sinks.dtype}.")
+                sinks = sinks.to(torch.float32)
             self.sinks = sinks
 
     def forward(
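The backend now casts non-float32 sinks instead of rejecting them, which is what lets the model keep its bfloat16 parameter. A quick standalone check of that cast, using a made-up tensor:

import torch

# Made-up bfloat16 sinks, matching how the model now allocates them.
sinks = torch.zeros(16, dtype=torch.bfloat16)

# Mirrors the backend's defensive cast: only convert when necessary.
if sinks.dtype != torch.float32:
    sinks = sinks.to(torch.float32)

assert sinks.dtype == torch.float32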
@@ -285,6 +285,7 @@ class PerLayerParameters:
     window_left: int
     logits_soft_cap: Optional[float]
     sm_scale: float
+    has_sinks: bool = False
 
 
 def get_per_layer_parameters(
@@ -307,9 +308,11 @@ def get_per_layer_parameters(
         window_left = window_size[0] if window_size is not None else -1
         logits_soft_cap = getattr(impl, "logits_soft_cap", None)
         sm_scale = impl.scale
+        has_sinks = getattr(impl, "sinks", None) is not None
 
         per_layer_params[key] = PerLayerParameters(window_left,
-                                                   logits_soft_cap, sm_scale)
+                                                   logits_soft_cap, sm_scale,
+                                                   has_sinks)
 
     return per_layer_params
 
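Detection is attribute-based: a layer is considered to have sinks simply when its attention implementation exposes a non-None sinks attribute. A self-contained sketch of that idiom; the FakeImpl class and PerLayerParametersSketch dataclass are invented stand-ins, only the getattr pattern and the field order come from the diff:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class PerLayerParametersSketch:
    """Mirrors the fields of PerLayerParameters in the diff."""
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float
    has_sinks: bool = False


class FakeImpl:
    """Invented stand-in for an attention impl that carries sinks."""
    scale = 0.125
    logits_soft_cap = None
    sinks = torch.zeros(16, dtype=torch.float32)


impl = FakeImpl()
params = PerLayerParametersSketch(
    window_left=-1,
    logits_soft_cap=getattr(impl, "logits_soft_cap", None),
    sm_scale=impl.scale,
    has_sinks=getattr(impl, "sinks", None) is not None,
)
assert params.has_sinks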