[MM][Core] Decouple ViT backend from LM backend (#27061)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang 2025-10-21 00:30:10 -07:00 committed by GitHub
parent 72f431e709
commit c3a2c6ac5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 230 additions and 17 deletions

View File

@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig
def test_mm_encoder_attn_backend_str_conversion():
config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN
def test_mm_encoder_attn_backend_invalid():
with pytest.raises(ValueError):
MultiModalConfig(mm_encoder_attn_backend="not_a_backend")
def test_mm_encoder_attn_backend_hash_updates():
base_hash = MultiModalConfig().compute_hash()
overridden_hash = MultiModalConfig(
mm_encoder_attn_backend=_Backend.FLASH_ATTN
).compute_hash()
assert base_hash != overridden_hash

View File

@ -16,6 +16,7 @@ from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.config.vllm import VllmConfig
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
@ -443,6 +444,7 @@ class MultiHeadAttention(nn.Module):
# This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
self.num_heads = num_heads
@ -462,7 +464,14 @@ class MultiHeadAttention(nn.Module):
dtype = torch.get_default_dtype()
# Determine the attention backend
backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)
# Some auto-selected backends can be upgraded
# to upstream flash attention if available.

View File

@ -50,6 +50,7 @@ if TYPE_CHECKING:
import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models
from vllm.attention.backends.registry import _Backend
from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -57,6 +58,7 @@ if TYPE_CHECKING:
else:
PretrainedConfig = Any
_Backend = Any
me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization"
)
@ -307,6 +309,7 @@ class ModelConfig:
mm_processor_cache_type: InitVar[MMCacheType | None] = None
mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
mm_encoder_attn_backend: InitVar[_Backend | str | None] = None
interleave_mm_strings: InitVar[bool | None] = None
skip_mm_profiling: InitVar[bool | None] = None
video_pruning_rate: InitVar[float | None] = None
@ -424,6 +427,7 @@ class ModelConfig:
mm_processor_cache_type: MMCacheType | None,
mm_shm_cache_max_object_size_mb: int | None,
mm_encoder_tp_mode: MMEncoderTPMode | None,
mm_encoder_attn_backend: _Backend | str | None,
interleave_mm_strings: bool | None,
skip_mm_profiling: bool | None,
video_pruning_rate: float | None,
@ -733,6 +737,7 @@ class ModelConfig:
mm_processor_cache_type=mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb,
mm_encoder_tp_mode=mm_encoder_tp_mode,
mm_encoder_attn_backend=mm_encoder_attn_backend,
interleave_mm_strings=interleave_mm_strings,
skip_mm_profiling=skip_mm_profiling,
video_pruning_rate=video_pruning_rate,

View File

@ -3,13 +3,18 @@
import hashlib
from collections.abc import Mapping
from typing import Any, Literal, TypeAlias
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
else:
_Backend = Any
@dataclass
class BaseDummyOptions:
@ -112,6 +117,10 @@ class MultiModalConfig:
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
mm_encoder_attn_backend: _Backend | None = None
"""Optional override for the multi-modal encoder attention backend when
using vision transformers. Accepts any value from
`vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string."""
@ -148,6 +157,29 @@ class MultiModalConfig:
value[k] = BaseDummyOptions(**v)
return value
@field_validator("mm_encoder_attn_backend", mode="before")
@classmethod
def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
from vllm.attention.backends.registry import (
_Backend as BackendEnum,
)
from vllm.attention.backends.registry import (
backend_name_to_enum,
)
if value is None or isinstance(value, BackendEnum):
return value
if isinstance(value, str):
candidate = backend_name_to_enum(value.upper())
if candidate is not None:
return candidate
valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
raise ValueError(
f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
)
@model_validator(mode="after")
def _validate_multimodal_config(self):
if self.mm_processor_cache_type != "shm" and (
@ -172,9 +204,11 @@ class MultiModalConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
factors: list[Any] = [
self.mm_encoder_attn_backend.name
if self.mm_encoder_attn_backend is not None
else None
]
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

View File

@ -32,6 +32,7 @@ from pydantic.fields import FieldInfo
from typing_extensions import TypeIs, deprecated
import vllm.envs as envs
from vllm.attention.backends.registry import _Backend
from vllm.config import (
CacheConfig,
CompilationConfig,
@ -451,6 +452,9 @@ class EngineArgs:
MultiModalConfig.mm_shm_cache_max_object_size_mb
)
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
mm_encoder_attn_backend: _Backend | str | None = (
MultiModalConfig.mm_encoder_attn_backend
)
io_processor_plugin: str | None = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
video_pruning_rate: float = MultiModalConfig.video_pruning_rate
@ -914,6 +918,10 @@ class EngineArgs:
multimodal_group.add_argument(
"--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
)
multimodal_group.add_argument(
"--mm-encoder-attn-backend",
**multimodal_kwargs["mm_encoder_attn_backend"],
)
multimodal_group.add_argument(
"--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
)
@ -1160,6 +1168,7 @@ class EngineArgs:
mm_processor_cache_type=self.mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
mm_encoder_attn_backend=self.mm_encoder_attn_backend,
pooler_config=self.pooler_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,

View File

@ -256,6 +256,7 @@ class DotsVisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
@ -288,7 +289,9 @@ class DotsVisionAttention(nn.Module):
)
# Select attention backend
self.attn_backend = get_vit_attn_backend(
self.hidden_size_per_attention_head, torch.get_default_dtype()
self.hidden_size_per_attention_head,
torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
@ -510,6 +513,7 @@ class DotsVisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
):
super().__init__()
@ -521,6 +525,7 @@ class DotsVisionBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
self.mlp = DotsSwiGLUFFN(
@ -561,6 +566,7 @@ class DotsVisionTransformer(nn.Module):
require_post_norm: bool | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
self.config = config
@ -571,7 +577,9 @@ class DotsVisionTransformer(nn.Module):
head_dim = config.embed_dim // config.num_attention_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype()
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
@ -591,6 +599,7 @@ class DotsVisionTransformer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.blocks.{i}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for i in range(num_layers)
]
@ -750,11 +759,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
self.config.vision_config = vision_config
else:
vision_config = self.config.vision_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_tower = DotsVisionTransformer(
vision_config,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config,

View File

@ -164,6 +164,7 @@ class Ernie4_5_VisionAttention(nn.Module):
projection_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@ -196,6 +197,7 @@ class Ernie4_5_VisionAttention(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
@ -367,6 +369,7 @@ class Ernie4_5_VisionBlock(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
@ -382,6 +385,7 @@ class Ernie4_5_VisionBlock(nn.Module):
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_backend_override=attn_backend_override,
)
self.mlp = Ernie4_5_VisionMLP(
@ -458,6 +462,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
patch_size = vision_config.patch_size
@ -493,6 +498,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@ -504,7 +510,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype()
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
@ -1327,11 +1335,17 @@ class Ernie4_5_VLMoeForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_model = Ernie4_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
attn_backend_override=attn_backend_override,
)
self.language_model = Ernie4_5_VLMoeForCausalLM(

View File

@ -247,6 +247,7 @@ class Glm4vVisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@ -287,6 +288,7 @@ class Glm4vVisionAttention(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
@ -417,6 +419,7 @@ class Glm4vVisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@ -430,6 +433,7 @@ class Glm4vVisionBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.mlp = Glm4vVisionMLP(
dim,
@ -696,6 +700,7 @@ class Glm4vVisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
@ -731,6 +736,7 @@ class Glm4vVisionTransformer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@ -759,7 +765,9 @@ class Glm4vVisionTransformer(nn.Module):
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype()
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
@ -1437,12 +1445,18 @@ class Glm4vForConditionalGeneration(
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Glm4vVisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
if config.model_type == "glm4v":

View File

@ -353,6 +353,7 @@ class KeyeSiglipAttention(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
):
super().__init__()
self.config = config
@ -392,7 +393,9 @@ class KeyeSiglipAttention(nn.Module):
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
head_size=self.head_dim, dtype=torch.get_default_dtype()
head_size=self.head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
@ -521,6 +524,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@ -529,6 +533,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@ -573,6 +578,7 @@ class KeyeSiglipEncoder(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
):
super().__init__()
self.config = config
@ -585,6 +591,7 @@ class KeyeSiglipEncoder(nn.Module):
config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend_override=attn_backend_override,
)
for layer_idx in range(config.num_hidden_layers)
]
@ -666,6 +673,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
):
super().__init__()
self.config = config
@ -676,6 +684,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
config,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@ -747,6 +756,7 @@ class KeyeSiglipVisionModel(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
):
super().__init__()
@ -754,6 +764,7 @@ class KeyeSiglipVisionModel(nn.Module):
config,
quant_config=quant_config,
prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
)
self.quant_config = quant_config
@ -1296,10 +1307,16 @@ class BaseKeyeModule(nn.Module):
self.config = config
self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
)
self.mlp_AR = self._build_projector(

View File

@ -10,6 +10,7 @@ import torch
import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
@ -105,6 +106,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
):
super().__init__()
self.config = config
@ -113,6 +115,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.vit",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
# reserved tokens for INDICATOR_IDS
head_dim = visual_vocab_size - len(INDICATOR_IDS)
@ -132,6 +135,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
):
model_type = config.model_type
if model_type == "siglip2_navit":
@ -140,6 +144,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config=quant_config,
prefix=prefix,
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")
@ -457,6 +462,7 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config: PretrainedConfig = config
self.llm = init_vllm_registered_model(
@ -464,11 +470,17 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "llm"),
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual_tokenizer = VisualTokenizer(
config=config.vit_config,
visual_vocab_size=config.visual_vocab_size,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
attn_backend_override=attn_backend_override,
)
self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)

View File

@ -637,6 +637,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
@ -669,7 +670,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
use_upstream_fa = False
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype()
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if (
self.attn_backend != _Backend.FLASH_ATTN
@ -1226,12 +1229,18 @@ class Qwen2_5_VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
else:
self.visual = None

View File

@ -320,6 +320,7 @@ class Qwen2VisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@ -355,6 +356,7 @@ class Qwen2VisionAttention(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
@ -497,6 +499,7 @@ class Qwen2VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@ -512,6 +515,7 @@ class Qwen2VisionBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2VisionMLP(
dim,
@ -662,6 +666,7 @@ class Qwen2VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
@ -703,6 +708,7 @@ class Qwen2VisionTransformer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@ -716,7 +722,9 @@ class Qwen2VisionTransformer(nn.Module):
use_data_parallel=use_data_parallel,
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype()
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
@ -1356,12 +1364,18 @@ class Qwen2VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
else:
self.visual = None

View File

@ -296,6 +296,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@ -367,7 +368,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype()
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
@ -1144,11 +1147,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
)
self.quant_config = quant_config

View File

@ -300,6 +300,7 @@ class Qwen3_VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@ -359,7 +360,9 @@ class Qwen3_VisionTransformer(nn.Module):
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype()
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
use_upstream_fa = False
if (
@ -379,7 +382,6 @@ class Qwen3_VisionTransformer(nn.Module):
raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now."
)
self.blocks = nn.ModuleList(
[
Qwen3_VisionBlock(
@ -1214,12 +1216,18 @@ class Qwen3VLForConditionalGeneration(
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.language_model = Qwen3LLMForCausalLM(

View File

@ -208,6 +208,7 @@ class Siglip2Attention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
):
super().__init__()
self.config = config
@ -248,7 +249,9 @@ class Siglip2Attention(nn.Module):
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
head_size=self.head_dim, dtype=torch.get_default_dtype()
head_size=self.head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
@ -372,6 +375,7 @@ class Siglip2EncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@ -381,6 +385,7 @@ class Siglip2EncoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(
@ -434,6 +439,7 @@ class Siglip2Encoder(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
):
super().__init__()
self.config = config
@ -444,6 +450,7 @@ class Siglip2Encoder(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.layers.{idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for idx in range(config.num_hidden_layers)
]
@ -618,6 +625,7 @@ class Siglip2VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
):
super().__init__()
self.config = config
@ -629,6 +637,7 @@ class Siglip2VisionTransformer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@ -657,6 +666,7 @@ class Siglip2NavitModel(torch.nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
):
super().__init__()
@ -665,6 +675,7 @@ class Siglip2NavitModel(torch.nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.vision_model",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
def forward(

View File

@ -78,10 +78,18 @@ def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInf
raise NotImplementedError(msg)
def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
def get_vit_attn_backend(
head_size: int,
dtype: torch.dtype,
*,
attn_backend_override: _Backend | None = None,
) -> _Backend:
"""
Get the available attention backend for Vision Transformer.
"""
if attn_backend_override is not None:
return attn_backend_override
# Lazy import to avoid circular dependency
from vllm.attention.selector import get_env_variable_attn_backend