mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:55:51 +08:00
[Bugfix][Flashinfer] fix VLLM_USE_TRTLLM_ATTENTION issue for models with diff hyperparameters (#25924)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
parent
76879cc160
commit
b82f4307c9
@ -37,7 +37,6 @@ from vllm.utils import cdiv, is_pin_memory_available
|
||||
from vllm.utils.flashinfer import (
|
||||
can_use_trtllm_attention,
|
||||
flashinfer_disable_q_quantization,
|
||||
supports_trtllm_attention,
|
||||
use_trtllm_attention,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
@ -323,15 +322,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
|
||||
# use fp8 q if kv cache is fp8, and will fall back to model dtype
|
||||
# if TRTLLM attention kernel is not used when building attn metadata
|
||||
if supports_trtllm_attention() and not flashinfer_disable_q_quantization():
|
||||
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
|
||||
if can_use_trtllm and not flashinfer_disable_q_quantization():
|
||||
self.q_data_type = self.kv_cache_dtype
|
||||
else:
|
||||
self.q_data_type = self.model_config.dtype
|
||||
|
||||
supports_spec_as_decode = can_use_trtllm_attention(
|
||||
self.num_qo_heads, self.num_kv_heads
|
||||
)
|
||||
self._init_reorder_batch_threshold(1, supports_spec_as_decode)
|
||||
self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
|
||||
|
||||
self._cascade_wrapper = None # Wrapper for cascade attention
|
||||
|
||||
@ -344,7 +341,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.window_left = self.global_hyperparameters.window_left
|
||||
self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
|
||||
self.has_sinks = self.global_hyperparameters.has_sinks
|
||||
if self.has_sinks and not supports_trtllm_attention():
|
||||
if self.has_sinks and not can_use_trtllm:
|
||||
raise NotImplementedError(
|
||||
"FlashInfer backend currently does not support attention "
|
||||
"sinks, please use trtllm on blackwell or flash attention on "
|
||||
@ -548,16 +545,30 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
has_sinks=self.has_sinks,
|
||||
has_spec=uses_spec_reorder,
|
||||
)
|
||||
if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
|
||||
raise NotImplementedError(
|
||||
"FlashInfer backend currently does not support attention "
|
||||
"sinks, please use trtllm on blackwell or flash attention on "
|
||||
"earlier GPUs."
|
||||
|
||||
if not (prefill_use_trtllm and decode_use_trtllm):
|
||||
if self.has_sinks:
|
||||
raise NotImplementedError(
|
||||
"FlashInfer backend currently does not support attention "
|
||||
"sinks, please use trtllm on blackwell or flash attention "
|
||||
"on earlier GPUs."
|
||||
)
|
||||
|
||||
if not self.global_hyperparameters.has_same_window_lefts:
|
||||
raise ValueError(
|
||||
"Window left is not the same for all layers. "
|
||||
"One potential fix is to set disable_sliding_window=True"
|
||||
)
|
||||
|
||||
assert self.global_hyperparameters.has_same_all_params, (
|
||||
"FlashInfer backend currently only supports models in which "
|
||||
"all layers share the same values for the following "
|
||||
"hyperparameters: `window_left`, `logits_soft_cap`, "
|
||||
"`sm_scale`."
|
||||
)
|
||||
|
||||
# If TRTLLM attention is not used, the q quantization is not supported.
|
||||
# Fall back to use model dtype.
|
||||
if not (prefill_use_trtllm and decode_use_trtllm):
|
||||
# The q quantization is not supported for non-trtllm attention,
|
||||
# fall back to model dtype.
|
||||
self.q_data_type = self.model_config.dtype
|
||||
|
||||
attn_metadata = FlashInferMetadata(
|
||||
@ -772,9 +783,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
)
|
||||
self.sinks = sinks
|
||||
|
||||
self.support_trtllm_attn = (
|
||||
supports_trtllm_attention() and num_heads % num_kv_heads == 0
|
||||
)
|
||||
self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
|
||||
self.bmm1_scale: float | None = None
|
||||
self.bmm2_scale: float | None = None
|
||||
self.o_sf_scale: float | None = None
|
||||
|
||||
@ -4,7 +4,7 @@ import abc
|
||||
import enum
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, fields, make_dataclass
|
||||
from dataclasses import dataclass, field, fields, make_dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -233,7 +233,7 @@ class AttentionCGSupport(enum.Enum):
|
||||
"""Cudagraph always supported; supports mixed-prefill-decode"""
|
||||
UNIFORM_BATCH = 2
|
||||
"""Cudagraph supported for batches the only contain query lengths that are
|
||||
the same, this can be used for spec-decode
|
||||
the same, this can be used for spec-decode
|
||||
i.e. "decodes" are 1 + num_speculative_tokens"""
|
||||
UNIFORM_SINGLE_TOKEN_DECODE = 1
|
||||
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
|
||||
@ -395,6 +395,9 @@ class PerLayerParameters:
|
||||
logits_soft_cap: Optional[float]
|
||||
sm_scale: float
|
||||
has_sinks: bool = False
|
||||
# has same params for all layers
|
||||
has_same_window_lefts: Optional[bool] = field(default=None, compare=False)
|
||||
has_same_all_params: Optional[bool] = field(default=None, compare=False)
|
||||
|
||||
|
||||
def get_per_layer_parameters(
|
||||
@ -446,20 +449,12 @@ def infer_global_hyperparameters(
|
||||
param_sets = list(per_layer_params.values())
|
||||
global_params = param_sets[0]
|
||||
|
||||
# trtllm attention doesn't need global hyper params so disable the check
|
||||
if not envs.VLLM_USE_TRTLLM_ATTENTION:
|
||||
for params in param_sets:
|
||||
if params.window_left != global_params.window_left:
|
||||
raise ValueError(
|
||||
"Window left is not the same for all layers. "
|
||||
"One potential fix is to set disable_sliding_window=True"
|
||||
)
|
||||
assert params == global_params, (
|
||||
"FlashInfer backend currently only supports models in which all"
|
||||
"layers share the same values "
|
||||
"for the following hyperparameters:"
|
||||
"`window_left`, `logits_soft_cap`, `sm_scale`."
|
||||
)
|
||||
global_params.has_same_window_lefts = all(
|
||||
params.window_left == global_params.window_left for params in param_sets
|
||||
)
|
||||
global_params.has_same_all_params = all(
|
||||
params == global_params for params in param_sets
|
||||
)
|
||||
|
||||
return global_params
|
||||
|
||||
@ -925,8 +920,8 @@ def create_fast_prefill_custom_backend(
|
||||
):
|
||||
def __init__(self, metadata, common_attn_metadata):
|
||||
# Shallow copy all fields in metadata cls
|
||||
for field in fields(metadata.__class__):
|
||||
setattr(self, field.name, getattr(metadata, field.name))
|
||||
for _field in fields(metadata.__class__):
|
||||
setattr(self, _field.name, getattr(metadata, _field.name))
|
||||
|
||||
# Set additional fields that will be used in model code
|
||||
assert (
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user