[Bugfix][Flashinfer] fix VLLM_USE_TRTLLM_ATTENTION issue for models with diff hyperparameters (#25924)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
commit b82f4307c9
parent 76879cc160
@@ -37,7 +37,6 @@ from vllm.utils import cdiv, is_pin_memory_available
 from vllm.utils.flashinfer import (
     can_use_trtllm_attention,
     flashinfer_disable_q_quantization,
-    supports_trtllm_attention,
     use_trtllm_attention,
 )
 from vllm.v1.attention.backends.utils import (
@@ -323,15 +322,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
         # use fp8 q if kv cache is fp8, and will fall back to model dtype
         # if TRTLLM attention kernel is not used when building attn metadata
-        if supports_trtllm_attention() and not flashinfer_disable_q_quantization():
+        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
+        if can_use_trtllm and not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
             self.q_data_type = self.model_config.dtype

-        supports_spec_as_decode = can_use_trtllm_attention(
-            self.num_qo_heads, self.num_kv_heads
-        )
-        self._init_reorder_batch_threshold(1, supports_spec_as_decode)
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)

         self._cascade_wrapper = None  # Wrapper for cascade attention

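For readers following the q-dtype logic above, here is a minimal, self-contained sketch of the selection rule this hunk implements. The function and flag names are illustrative stand-ins, not vLLM APIs; only the decision shape mirrors the diff.

# Illustrative sketch of the q-dtype selection above (names are hypothetical).
def choose_q_dtype(can_use_trtllm: bool,
                   q_quant_disabled: bool,
                   kv_cache_dtype: str,
                   model_dtype: str) -> str:
    # Adopt the (possibly fp8) kv-cache dtype for q only when the per-model
    # TRTLLM check passes and q quantization has not been disabled.
    if can_use_trtllm and not q_quant_disabled:
        return kv_cache_dtype
    return model_dtype

# A model whose head counts rule out the TRTLLM kernels keeps the model dtype.
assert choose_q_dtype(True, False, "fp8", "bfloat16") == "fp8"
assert choose_q_dtype(False, False, "fp8", "bfloat16") == "bfloat16"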
@@ -344,7 +341,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not supports_trtllm_attention():
+        if self.has_sinks and not can_use_trtllm:
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
@@ -548,16 +545,30 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             has_sinks=self.has_sinks,
             has_spec=uses_spec_reorder,
         )
-        if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
-            raise NotImplementedError(
-                "FlashInfer backend currently does not support attention "
-                "sinks, please use trtllm on blackwell or flash attention on "
-                "earlier GPUs."
-            )

-        # If TRTLLM attention is not used, the q quantization is not supported.
-        # Fall back to use model dtype.
-        if not (prefill_use_trtllm and decode_use_trtllm):
+        if not (prefill_use_trtllm and decode_use_trtllm):
+            if self.has_sinks:
+                raise NotImplementedError(
+                    "FlashInfer backend currently does not support attention "
+                    "sinks, please use trtllm on blackwell or flash attention "
+                    "on earlier GPUs."
+                )
+
+            if not self.global_hyperparameters.has_same_window_lefts:
+                raise ValueError(
+                    "Window left is not the same for all layers. "
+                    "One potential fix is to set disable_sliding_window=True"
+                )
+
+            assert self.global_hyperparameters.has_same_all_params, (
+                "FlashInfer backend currently only supports models in which "
+                "all layers share the same values for the following "
+                "hyperparameters: `window_left`, `logits_soft_cap`, "
+                "`sm_scale`."
+            )
+
+            # The q quantization is not supported for non-trtllm attention,
+            # fall back to model dtype.
             self.q_data_type = self.model_config.dtype

         attn_metadata = FlashInferMetadata(
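A rough sketch of the control flow introduced here, with hypothetical names (this is not the vLLM implementation): the per-layer hyperparameter checks that previously ran unconditionally in infer_global_hyperparameters now run per metadata build, and only when the non-TRTLLM fallback path will actually be taken.

# Hypothetical condensed version of the fallback-path validation above.
def validate_fallback_path(prefill_use_trtllm: bool,
                           decode_use_trtllm: bool,
                           has_sinks: bool,
                           has_same_window_lefts: bool,
                           has_same_all_params: bool) -> None:
    if prefill_use_trtllm and decode_use_trtllm:
        # The TRTLLM path does not rely on the global hyperparameters,
        # so there is nothing to enforce.
        return
    if has_sinks:
        raise NotImplementedError("attention sinks need trtllm or flash attention")
    if not has_same_window_lefts:
        raise ValueError("window_left differs across layers; "
                         "one fix is disable_sliding_window=True")
    assert has_same_all_params, (
        "non-trtllm FlashInfer path needs identical window_left, "
        "logits_soft_cap and sm_scale across layers"
    )

# A model with mixed hyperparameters passes as long as both phases use TRTLLM.
validate_fallback_path(True, True, False, False, False)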
@@ -772,9 +783,7 @@ class FlashInferImpl(AttentionImpl):
         )
         self.sinks = sinks

-        self.support_trtllm_attn = (
-            supports_trtllm_attention() and num_heads % num_kv_heads == 0
-        )
+        self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
         self.o_sf_scale: float | None = None
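The removed expression is useful documentation for the rest of the diff: judging by it, can_use_trtllm_attention(num_heads, num_kv_heads) folds the per-model head-count requirement into the platform-level supports_trtllm_attention() check, roughly as sketched below. The real helper lives in vllm.utils.flashinfer and may check more than this; the sketch only captures what the removed code made explicit.

# Rough sketch inferred from the removed expression; not the actual helper.
def can_use_trtllm_attention_sketch(num_qo_heads: int,
                                    num_kv_heads: int,
                                    platform_supports_trtllm: bool) -> bool:
    # Platform/kernel availability plus the grouped-query head-count constraint.
    return platform_supports_trtllm and num_qo_heads % num_kv_heads == 0

# 32 query heads over 8 kv heads is accepted; 28 over 8 is not.
assert can_use_trtllm_attention_sketch(32, 8, True)
assert not can_use_trtllm_attention_sketch(28, 8, True)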
@@ -4,7 +4,7 @@ import abc
 import enum
 import functools
 from abc import abstractmethod
-from dataclasses import dataclass, fields, make_dataclass
+from dataclasses import dataclass, field, fields, make_dataclass
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -395,6 +395,9 @@ class PerLayerParameters:
     logits_soft_cap: Optional[float]
     sm_scale: float
     has_sinks: bool = False
+    # has same params for all layers
+    has_same_window_lefts: Optional[bool] = field(default=None, compare=False)
+    has_same_all_params: Optional[bool] = field(default=None, compare=False)


 def get_per_layer_parameters(
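The two new flags are declared with field(..., compare=False), so they are excluded from the dataclass-generated __eq__. That matters because has_same_all_params is itself computed by comparing PerLayerParameters instances (see the next hunk), and the cached flag must not perturb that comparison. A self-contained illustration of the dataclasses behavior, using a toy class rather than PerLayerParameters:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Params:  # toy stand-in, not the real PerLayerParameters
    window_left: int
    sm_scale: float
    # compare=False keeps this field out of the generated __eq__.
    has_same_all_params: Optional[bool] = field(default=None, compare=False)

a = Params(window_left=-1, sm_scale=0.125)
b = Params(window_left=-1, sm_scale=0.125)
a.has_same_all_params = True
assert a == b  # still equal: the compare=False field is ignored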
@@ -446,19 +449,11 @@ def infer_global_hyperparameters(
     param_sets = list(per_layer_params.values())
     global_params = param_sets[0]

-    # trtllm attention doesn't need global hyper params so disable the check
-    if not envs.VLLM_USE_TRTLLM_ATTENTION:
-        for params in param_sets:
-            if params.window_left != global_params.window_left:
-                raise ValueError(
-                    "Window left is not the same for all layers. "
-                    "One potential fix is to set disable_sliding_window=True"
-                )
-            assert params == global_params, (
-                "FlashInfer backend currently only supports models in which all"
-                "layers share the same values "
-                "for the following hyperparameters:"
-                "`window_left`, `logits_soft_cap`, `sm_scale`."
-            )
+    global_params.has_same_window_lefts = all(
+        params.window_left == global_params.window_left for params in param_sets
+    )
+    global_params.has_same_all_params = all(
+        params == global_params for params in param_sets
+    )

     return global_params
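As a usage example of the new aggregation (again with a toy dataclass standing in for PerLayerParameters): a model that interleaves sliding-window and full-attention layers now just records has_same_window_lefts=False here, and whether that is fatal is decided later by the FlashInfer metadata builder, only on the non-TRTLLM path.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Params:  # toy stand-in, not the real PerLayerParameters
    window_left: int
    sm_scale: float
    has_same_window_lefts: Optional[bool] = field(default=None, compare=False)
    has_same_all_params: Optional[bool] = field(default=None, compare=False)

# Layers alternating between sliding-window (1024) and full attention (-1).
param_sets = [Params(1024, 0.125), Params(-1, 0.125), Params(1024, 0.125)]
global_params = param_sets[0]
global_params.has_same_window_lefts = all(
    p.window_left == global_params.window_left for p in param_sets
)
global_params.has_same_all_params = all(p == global_params for p in param_sets)
assert not global_params.has_same_window_lefts
assert not global_params.has_same_all_params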
@@ -925,8 +920,8 @@ def create_fast_prefill_custom_backend(
     ):
         def __init__(self, metadata, common_attn_metadata):
             # Shallow copy all fields in metadata cls
-            for field in fields(metadata.__class__):
-                setattr(self, field.name, getattr(metadata, field.name))
+            for _field in fields(metadata.__class__):
+                setattr(self, _field.name, getattr(metadata, _field.name))

             # Set additional fields that will be used in model code
             assert (
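The loop-variable rename pairs with the new `from dataclasses import field` at the top of the file: presumably it is there to avoid reusing the name field, which would rebind it away from the imported function within that scope. A small standalone example of the pattern (toy classes, unrelated to vLLM):

from dataclasses import dataclass, field, fields

@dataclass
class Point:  # toy dataclass
    x: int = 0
    y: int = field(default=0)

class Target:  # plain object receiving the shallow copy
    pass

def shallow_copy_fields(dst, src):
    # Avoid naming the loop variable `field`: that would shadow
    # dataclasses.field for the rest of this function's scope.
    for _field in fields(src.__class__):
        setattr(dst, _field.name, getattr(src, _field.name))

dst = Target()
shallow_copy_fields(dst, Point(1, 2))
assert (dst.x, dst.y) == (1, 2)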