mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 06:47:04 +08:00
Fix trtllm-gen attention env and add attention sink (#22378)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com> Signed-off-by: Lain <fusiyuan2000@hotmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
parent
5c7cc33f4d
commit
9a3835aaa9
13
vllm/envs.py
13
vllm/envs.py
@ -152,8 +152,7 @@ if TYPE_CHECKING:
|
||||
VLLM_LOOPBACK_IP: str = ""
|
||||
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
|
||||
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
|
||||
VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
|
||||
VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False
|
||||
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
|
||||
|
||||
@ -1043,13 +1042,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_CUDNN_PREFILL":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
|
||||
|
||||
# If set to 1, use the TRTLLM Context Attention backend in flashinfer.
|
||||
"VLLM_USE_TRTLLM_CONTEXT_ATTENTION":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0"))),
|
||||
|
||||
# If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
|
||||
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0"))),
|
||||
# If set to 1, use the TRTLLM attention backend in flashinfer.
|
||||
"VLLM_USE_TRTLLM_ATTENTION":
|
||||
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
|
||||
|
||||
# Controls garbage collection during CUDA graph capture.
|
||||
# If set to 0 (default), enables GC freezing to speed up capture time.
|
||||
|
||||
@ -70,9 +70,8 @@ class OAIAttention(nn.Module):
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
attention_sink_dtype = (
|
||||
torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
|
||||
or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16)
|
||||
attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
|
||||
else torch.bfloat16)
|
||||
self.sinks = torch.nn.Parameter(
|
||||
torch.empty(config.num_attention_heads // tp_size,
|
||||
dtype=attention_sink_dtype,
|
||||
|
||||
@ -159,7 +159,7 @@ def use_trtllm_attention(
|
||||
|
||||
# Check if the dimensions are supported by TRTLLM decode attention
|
||||
if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
|
||||
or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
|
||||
or num_qo_heads % num_kv_heads != 0):
|
||||
return False
|
||||
|
||||
env_value = envs.VLLM_USE_TRTLLM_ATTENTION
|
||||
@ -169,10 +169,10 @@ def use_trtllm_attention(
|
||||
# Making the conditional check for zero because
|
||||
# the path is automatically enabled if the batch size condition
|
||||
# is satisfied.
|
||||
no_use_trtllm = (env_value == "0")
|
||||
if not no_use_trtllm:
|
||||
use_trtllm = (env_value == "1")
|
||||
if use_trtllm:
|
||||
logger.info_once("Using TRTLLM attention.")
|
||||
return not no_use_trtllm
|
||||
return use_trtllm
|
||||
else:
|
||||
# Environment variable not set - use auto-detection
|
||||
use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
|
||||
|
||||
@ -215,6 +215,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self._cascade_wrapper = None # Wrapper for cascade attention
|
||||
|
||||
# Global hyperparameters shared by all attention layers
|
||||
# TODO: discard this for trtllm-gen backend
|
||||
self.global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
|
||||
|
||||
@ -523,16 +524,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
head_dim = self.kv_cache_spec.head_size
|
||||
|
||||
# currently prefill trtllm attention does not support fp8 kv cache
|
||||
# trtllm may not support sliding window
|
||||
prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
|
||||
and not cache_dtype.startswith("fp8")
|
||||
and use_trtllm_attention(
|
||||
prefill_use_trtllm = use_trtllm_attention(
|
||||
num_prefill_tokens, max_seq_len, cache_dtype,
|
||||
num_qo_heads, num_kv_heads, head_dim))
|
||||
decode_use_trtllm = (self.global_hyperparameters.window_left == -1
|
||||
and use_trtllm_attention(
|
||||
num_qo_heads, num_kv_heads, head_dim)
|
||||
decode_use_trtllm = use_trtllm_attention(
|
||||
num_decode_tokens, max_seq_len, cache_dtype,
|
||||
num_qo_heads, num_kv_heads, head_dim))
|
||||
num_qo_heads, num_kv_heads, head_dim)
|
||||
|
||||
attn_metadata = FlashInferMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@ -793,6 +790,8 @@ class FlashInferImpl(AttentionImpl):
|
||||
batch_size=attn_metadata.num_prefills,
|
||||
cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
|
||||
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
|
||||
window_left=window_left,
|
||||
sinks=self.sinks,
|
||||
out=output[num_decode_tokens:],
|
||||
)
|
||||
|
||||
@ -839,6 +838,8 @@ class FlashInferImpl(AttentionImpl):
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
bmm1_scale=layer._k_scale_float * self.scale,
|
||||
bmm2_scale=layer._v_scale_float,
|
||||
window_left=window_left,
|
||||
sinks=self.sinks,
|
||||
out=output[:num_decode_tokens],
|
||||
)
|
||||
return output_padded
|
||||
|
||||
@ -254,8 +254,7 @@ def get_kv_cache_layout():
|
||||
# Override with format specified by the user.
|
||||
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
|
||||
if cache_layout is None:
|
||||
if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
|
||||
or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
|
||||
if envs.VLLM_USE_TRTLLM_ATTENTION:
|
||||
cache_layout = "HND"
|
||||
else:
|
||||
cache_layout = get_kv_connector_cache_layout()
|
||||
@ -333,8 +332,7 @@ def infer_global_hyperparameters(
|
||||
global_params = param_sets[0]
|
||||
|
||||
# trtllm attention doesn't need global hyper params so disable the check
|
||||
if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
|
||||
and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
|
||||
if not envs.VLLM_USE_TRTLLM_ATTENTION:
|
||||
for params in param_sets:
|
||||
if params.window_left != global_params.window_left:
|
||||
raise ValueError(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user