[Bugfix] Allow Only SDPA Backend for ViT on B200 for Qwen3-VL (#25788)
Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
parent 32335c8b34
commit c2fa2d4dc9
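In short: `get_vit_attn_backend` may pick FLASH_ATTN or XFORMERS for the vision encoder, but both currently misbehave on Blackwell (compute capability 10.0, e.g. B200), so this commit pins the Qwen3-VL ViT to TORCH_SDPA there and, along the way, makes backend detection happen once per vision transformer instead of once per attention layer. A condensed sketch of the new override, using only names that appear in the diff below (illustrative, not the verbatim code):

from vllm.platforms import _Backend, current_platform

def resolve_vit_backend(detected: _Backend) -> _Backend:
    # is_device_capability(100) == SM 10.0, i.e. Blackwell-class GPUs (B200).
    if current_platform.is_device_capability(100) and detected != _Backend.TORCH_SDPA:
        # FLASH_ATTN / XFORMERS are broken on Blackwell for this ViT; force SDPA.
        return _Backend.TORCH_SDPA
    return detected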
vllm/model_executor/models/qwen2_5_vl.py

@@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
             disable_tp=use_data_parallel)
-        # Detect attention implementation.
-        self.attn_backend = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head,
-            dtype=torch.get_default_dtype())
-        self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
-
-        if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-                _Backend.ROCM_AITER_FA
-        }:
-            raise RuntimeError(
-                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
-            )
+        self.attn_backend = attn_backend
+        self.use_upstream_fa = use_upstream_fa
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
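The structural change in this hunk: each attention layer previously re-ran backend detection in its own constructor; now the backend and the upstream-flash-attn flag are detected once by the owning transformer and injected as constructor arguments. A standalone sketch of the detect-once-and-inject pattern (hypothetical names, not vllm code):

from enum import Enum

class Backend(Enum):
    TORCH_SDPA = "torch_sdpa"
    FLASH_ATTN = "flash_attn"

class VisionAttention:
    # Before: every layer re-ran backend detection in its own __init__.
    # After: the backend is a plain constructor argument.
    def __init__(self, attn_backend: Backend = Backend.TORCH_SDPA,
                 use_upstream_fa: bool = False) -> None:
        self.attn_backend = attn_backend
        self.use_upstream_fa = use_upstream_fa

# The owning transformer detects once and injects everywhere:
backend = Backend.FLASH_ATTN  # pretend this came from platform probing
layers = [VisionAttention(attn_backend=backend) for _ in range(4)]
assert all(layer.attn_backend is backend for layer in layers)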
@@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module):
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen2_5_VisionMLP(dim,
                                      mlp_hidden_dim,
                                      act_fn=act_fn,
@@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

+        use_upstream_fa = False
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
+            )
+
         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(dim=self.hidden_size,
-                                num_heads=self.num_heads,
-                                mlp_hidden_dim=vision_config.intermediate_size,
-                                act_fn=get_act_and_mul_fn(
-                                    vision_config.hidden_act),
-                                norm_layer=norm_layer,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.blocks.{layer_idx}",
-                                use_data_parallel=use_data_parallel)
-            for layer_idx in range(depth)
+            Qwen2_5_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=get_act_and_mul_fn(vision_config.hidden_act),
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
@@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype())
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN

     @property
     def dtype(self) -> torch.dtype:
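With both redundant detection sites gone, the backend chosen in `Qwen2_5_VisionTransformer.__init__` is, by construction, the one every block uses. A quick sanity check one might run against a constructed model (`vit` is an assumed instance; attribute names follow the diff above):

# `vit` is an assumed Qwen2_5_VisionTransformer instance (hypothetical setup).
print(vit.attn_backend)                    # backend detected once, at the top
print(vit.blocks[0].attn.attn_backend)     # injected copy inside each block
assert all(blk.attn.attn_backend == vit.attn_backend for blk in vit.blocks)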
vllm/model_executor/models/qwen3_vl.py

@@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
+from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import is_list_of
@@ -158,6 +158,8 @@ class Qwen3_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -170,7 +172,9 @@ class Qwen3_VisionBlock(nn.Module):
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen3_VisionMLP(dim,
                                    mlp_hidden_dim,
                                    act_fn=act_fn,
@@ -287,19 +291,6 @@ class Qwen3_VisionTransformer(nn.Module):
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

-        self.blocks = nn.ModuleList([
-            Qwen3_VisionBlock(
-                dim=self.hidden_size,
-                num_heads=self.num_heads,
-                mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}",
-                use_data_parallel=use_data_parallel)
-            for layer_idx in range(vision_config.depth)
-        ])
-
         self.merger = Qwen3_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
             context_dim=self.hidden_size,
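This deletion is half of a move: the `nn.ModuleList` reappears in the next hunk, below backend detection, because each `Qwen3_VisionBlock` now consumes `self.attn_backend` and `use_upstream_fa`, which must exist first. A toy version of the required ordering (hypothetical names):

class Block:
    def __init__(self, attn_backend: str) -> None:
        self.attn_backend = attn_backend

class Vit:
    def __init__(self, depth: int = 2) -> None:
        self.attn_backend = self._detect()        # 1. detect first
        self.blocks = [Block(self.attn_backend)   # 2. then build consumers
                       for _ in range(depth)]

    def _detect(self) -> str:
        return "TORCH_SDPA"  # placeholder for real platform probing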
@@ -325,10 +316,42 @@ class Qwen3_VisionTransformer(nn.Module):

         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
+        use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen3-VL does not support {self.attn_backend} backend now.")
+        if current_platform.is_device_capability(
+                100) and self.attn_backend != _Backend.TORCH_SDPA:
+            # TODO(Roger/Wentao): remove this after FA
+            # or XFORMERS's issue fixed on Blackwell
+            logger.info_once("Qwen3-VL vision attention does not support "
+                             f"{self.attn_backend} backend on Blackwell now. "
+                             "Vision attention backend is set to TORCH_SDPA.")
+            self.attn_backend = _Backend.TORCH_SDPA
+
+        self.blocks = nn.ModuleList([
+            Qwen3_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa)
+            for layer_idx in range(vision_config.depth)
+        ])
+
     @property
     def dtype(self) -> torch.dtype:
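For reference, `is_device_capability(100)` matches compute capability 10.0: vllm encodes the capability as major * 10 + minor. An equivalent check in plain PyTorch (my own helper, not vllm's API):

import torch

def is_blackwell() -> bool:
    """Rough stand-in for current_platform.is_device_capability(100)."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor == 100  # SM 10.0, e.g. B200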