[MM][Core] Decouple ViT backend from LM backend (#27061)
Signed-off-by: Roger Wang <hey@rogerw.io>
commit c3a2c6ac5f
parent 72f431e709
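This change adds an `mm_encoder_attn_backend` option to `MultiModalConfig` and threads it, as `attn_backend_override`, through every vision-transformer encoder down to `get_vit_attn_backend()`, so the ViT attention backend can be selected independently of the language-model backend. The sketch below is not part of the diff: the encoder class is hypothetical and the import path of `get_vit_attn_backend` is an assumption (the file is not named in this diff); it only illustrates the wiring pattern the model files repeat.

# Illustrative wiring (hypothetical encoder class; import path assumed).
import torch
import torch.nn as nn

from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig
from vllm.model_executor.models.vision import get_vit_attn_backend


class MyVisionEncoder(nn.Module):  # hypothetical, not part of vLLM
    def __init__(self, head_dim: int,
                 attn_backend_override: _Backend | None = None) -> None:
        super().__init__()
        # Same call pattern the vision models below use: the override,
        # when set, wins over backend auto-detection.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )


multimodal_config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
attn_backend_override = (
    multimodal_config.mm_encoder_attn_backend
    if multimodal_config is not None
    else None
)
encoder = MyVisionEncoder(head_dim=128, attn_backend_override=attn_backend_override)
assert encoder.attn_backend == _Backend.FLASH_ATTN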
tests/config/test_multimodal_config.py (new file, +25 lines)
@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.attention.backends.registry import _Backend
+from vllm.config.multimodal import MultiModalConfig
+
+
+def test_mm_encoder_attn_backend_str_conversion():
+    config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
+    assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN
+
+
+def test_mm_encoder_attn_backend_invalid():
+    with pytest.raises(ValueError):
+        MultiModalConfig(mm_encoder_attn_backend="not_a_backend")
+
+
+def test_mm_encoder_attn_backend_hash_updates():
+    base_hash = MultiModalConfig().compute_hash()
+    overridden_hash = MultiModalConfig(
+        mm_encoder_attn_backend=_Backend.FLASH_ATTN
+    ).compute_hash()
+    assert base_hash != overridden_hash
@@ -16,6 +16,7 @@ from vllm.attention.backends.registry import _Backend, backend_name_to_enum
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config.multimodal import MultiModalConfig
 from vllm.config.vllm import VllmConfig
 from vllm.distributed.kv_transfer import (
     get_kv_transfer_group,
@@ -443,6 +444,7 @@ class MultiHeadAttention(nn.Module):
         # This has no effect, it is only here to make it easier to swap
         # between Attention and MultiHeadAttention
         prefix: str = "",
+        multimodal_config: MultiModalConfig | None = None,
     ) -> None:
         super().__init__()
         self.num_heads = num_heads
@@ -462,7 +464,14 @@ class MultiHeadAttention(nn.Module):
         dtype = torch.get_default_dtype()

         # Determine the attention backend
-        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
+        attn_backend_override = None
+        if multimodal_config is not None:
+            attn_backend_override = multimodal_config.mm_encoder_attn_backend
+        backend = get_vit_attn_backend(
+            head_size=head_size,
+            dtype=dtype,
+            attn_backend_override=attn_backend_override,
+        )

         # Some auto-selected backends can be upgraded
         # to upstream flash attention if available.
@@ -50,6 +50,7 @@ if TYPE_CHECKING:

     import vllm.model_executor.layers.quantization as me_quant
     import vllm.model_executor.models as me_models
+    from vllm.attention.backends.registry import _Backend
     from vllm.config.load import LoadConfig
     from vllm.config.parallel import ParallelConfig
     from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -57,6 +58,7 @@ if TYPE_CHECKING:
 else:
     PretrainedConfig = Any

+    _Backend = Any
     me_quant = LazyLoader(
         "model_executor", globals(), "vllm.model_executor.layers.quantization"
     )
@@ -307,6 +309,7 @@ class ModelConfig:
     mm_processor_cache_type: InitVar[MMCacheType | None] = None
     mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
     mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
+    mm_encoder_attn_backend: InitVar[_Backend | str | None] = None
     interleave_mm_strings: InitVar[bool | None] = None
     skip_mm_profiling: InitVar[bool | None] = None
     video_pruning_rate: InitVar[float | None] = None
@@ -424,6 +427,7 @@ class ModelConfig:
         mm_processor_cache_type: MMCacheType | None,
         mm_shm_cache_max_object_size_mb: int | None,
         mm_encoder_tp_mode: MMEncoderTPMode | None,
+        mm_encoder_attn_backend: _Backend | str | None,
         interleave_mm_strings: bool | None,
         skip_mm_profiling: bool | None,
         video_pruning_rate: float | None,
@@ -733,6 +737,7 @@ class ModelConfig:
             mm_processor_cache_type=mm_processor_cache_type,
             mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb,
             mm_encoder_tp_mode=mm_encoder_tp_mode,
+            mm_encoder_attn_backend=mm_encoder_attn_backend,
             interleave_mm_strings=interleave_mm_strings,
             skip_mm_profiling=skip_mm_profiling,
             video_pruning_rate=video_pruning_rate,
@@ -3,13 +3,18 @@

 import hashlib
 from collections.abc import Mapping
-from typing import Any, Literal, TypeAlias
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias

 from pydantic import ConfigDict, Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass

 from vllm.config.utils import config

+if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
+else:
+    _Backend = Any
+

 @dataclass
 class BaseDummyOptions:
@@ -112,6 +117,10 @@ class MultiModalConfig:
     DP (which is controlled by `--data-parallel-size`).
     This is only supported on a per-model basis and falls back to
     `"weights"` if the encoder does not support DP."""
+    mm_encoder_attn_backend: _Backend | None = None
+    """Optional override for the multi-modal encoder attention backend when
+    using vision transformers. Accepts any value from
+    `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
     interleave_mm_strings: bool = False
     """Enable fully interleaved support for multimodal prompts, while using
     --chat-template-content-format=string."""
@@ -148,6 +157,29 @@ class MultiModalConfig:
                 value[k] = BaseDummyOptions(**v)
         return value

+    @field_validator("mm_encoder_attn_backend", mode="before")
+    @classmethod
+    def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
+        from vllm.attention.backends.registry import (
+            _Backend as BackendEnum,
+        )
+        from vllm.attention.backends.registry import (
+            backend_name_to_enum,
+        )
+
+        if value is None or isinstance(value, BackendEnum):
+            return value
+
+        if isinstance(value, str):
+            candidate = backend_name_to_enum(value.upper())
+            if candidate is not None:
+                return candidate
+
+        valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
+        raise ValueError(
+            f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
+        )
+
     @model_validator(mode="after")
     def _validate_multimodal_config(self):
         if self.mm_processor_cache_type != "shm" and (
@@ -172,9 +204,11 @@ class MultiModalConfig:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        # no factors to consider.
-        # this config will not affect the computation graph.
-        factors: list[Any] = []
+        factors: list[Any] = [
+            self.mm_encoder_attn_backend.name
+            if self.mm_encoder_attn_backend is not None
+            else None
+        ]
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
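The validator above normalizes string input by upper-casing it and mapping it through `backend_name_to_enum`, so lowercase spellings are accepted, while unknown names raise a `ValueError` listing the valid members; the chosen backend is also folded into `compute_hash()`, so an override produces a different config hash than the default. A small sketch of that behavior, assuming only what the hunks above implement:

from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig

# Lowercase strings are normalized via value.upper() + backend_name_to_enum.
cfg = MultiModalConfig(mm_encoder_attn_backend="flash_attn")
assert cfg.mm_encoder_attn_backend == _Backend.FLASH_ATTN

# The override participates in the config hash (see compute_hash above),
# so caches keyed on the hash are invalidated when it changes.
assert cfg.compute_hash() != MultiModalConfig().compute_hash()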
@@ -32,6 +32,7 @@ from pydantic.fields import FieldInfo
 from typing_extensions import TypeIs, deprecated

 import vllm.envs as envs
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (
     CacheConfig,
     CompilationConfig,
@@ -451,6 +452,9 @@ class EngineArgs:
         MultiModalConfig.mm_shm_cache_max_object_size_mb
     )
     mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
+    mm_encoder_attn_backend: _Backend | str | None = (
+        MultiModalConfig.mm_encoder_attn_backend
+    )
     io_processor_plugin: str | None = None
     skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
     video_pruning_rate: float = MultiModalConfig.video_pruning_rate
@@ -914,6 +918,10 @@ class EngineArgs:
         multimodal_group.add_argument(
             "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
         )
+        multimodal_group.add_argument(
+            "--mm-encoder-attn-backend",
+            **multimodal_kwargs["mm_encoder_attn_backend"],
+        )
         multimodal_group.add_argument(
             "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
         )
@@ -1160,6 +1168,7 @@ class EngineArgs:
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
+            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
            pooler_config=self.pooler_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
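`EngineArgs` gains a matching field that is exposed as `--mm-encoder-attn-backend` in the multimodal argument group and forwarded into `ModelConfig`. A hedged usage sketch follows; the model name is illustrative, and passing the option as a keyword to `LLM(...)` assumes the usual forwarding of extra keywords to `EngineArgs`.

# Offline example (illustrative): pick FlashAttention for the ViT encoder
# while leaving the language-model attention backend to auto-selection.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",   # illustrative model name
    mm_encoder_attn_backend="FLASH_ATTN",  # forwarded to MultiModalConfig
)

# The equivalent server flag added by this commit:
#   vllm serve <model> --mm-encoder-attn-backend FLASH_ATTN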
@@ -256,6 +256,7 @@ class DotsVisionAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -288,7 +289,9 @@ class DotsVisionAttention(nn.Module):
         )
         # Select attention backend
         self.attn_backend = get_vit_attn_backend(
-            self.hidden_size_per_attention_head, torch.get_default_dtype()
+            self.hidden_size_per_attention_head,
+            torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -510,6 +513,7 @@ class DotsVisionBlock(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()

@@ -521,6 +525,7 @@ class DotsVisionBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
         self.mlp = DotsSwiGLUFFN(
@@ -561,6 +566,7 @@ class DotsVisionTransformer(nn.Module):
         require_post_norm: bool | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -571,7 +577,9 @@ class DotsVisionTransformer(nn.Module):
         head_dim = config.embed_dim // config.num_attention_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -591,6 +599,7 @@ class DotsVisionTransformer(nn.Module):
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{i}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for i in range(num_layers)
             ]
@@ -750,11 +759,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
             self.config.vision_config = vision_config
         else:
             vision_config = self.config.vision_config
+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.vision_tower = DotsVisionTransformer(
             vision_config,
             quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "vision_tower"),
             use_data_parallel=self.use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -164,6 +164,7 @@ class Ernie4_5_VisionAttention(nn.Module):
         projection_size: int,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -196,6 +197,7 @@ class Ernie4_5_VisionAttention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )

         self.use_upstream_fa = False
@@ -367,6 +369,7 @@ class Ernie4_5_VisionBlock(nn.Module):
         norm_layer: Callable[[int], nn.Module] | None = None,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -382,6 +385,7 @@ class Ernie4_5_VisionBlock(nn.Module):
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            attn_backend_override=attn_backend_override,
         )

         self.mlp = Ernie4_5_VisionMLP(
@@ -458,6 +462,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
         norm_eps: float = 1e-6,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         patch_size = vision_config.patch_size
@@ -493,6 +498,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
                     norm_layer=norm_layer,
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]
@@ -504,7 +510,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
         self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -1327,11 +1335,17 @@ class Ernie4_5_VLMoeForConditionalGeneration(
         self.config = config
         self.multimodal_config = multimodal_config

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.vision_model = Ernie4_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "vision_model"),
+            attn_backend_override=attn_backend_override,
         )

         self.language_model = Ernie4_5_VLMoeForCausalLM(
@@ -247,6 +247,7 @@ class Glm4vVisionAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -287,6 +288,7 @@ class Glm4vVisionAttention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -417,6 +419,7 @@ class Glm4vVisionBlock(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -430,6 +433,7 @@ class Glm4vVisionBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.mlp = Glm4vVisionMLP(
             dim,
@@ -696,6 +700,7 @@ class Glm4vVisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -731,6 +736,7 @@ class Glm4vVisionTransformer(nn.Module):
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     use_data_parallel=self.use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]
@@ -759,7 +765,9 @@ class Glm4vVisionTransformer(nn.Module):
         )

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -1437,12 +1445,18 @@ class Glm4vForConditionalGeneration(
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
             use_data_parallel=self.use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )

         if config.model_type == "glm4v":
@@ -353,6 +353,7 @@ class KeyeSiglipAttention(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -392,7 +393,9 @@ class KeyeSiglipAttention(nn.Module):

         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
-            head_size=self.head_dim, dtype=torch.get_default_dtype()
+            head_size=self.head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )

         self.use_upstream_fa = False
@@ -521,6 +524,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -529,6 +533,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            attn_backend_override=attn_backend_override,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -573,6 +578,7 @@ class KeyeSiglipEncoder(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -585,6 +591,7 @@ class KeyeSiglipEncoder(nn.Module):
                     config,
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{layer_idx}",
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(config.num_hidden_layers)
             ]
@@ -666,6 +673,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -676,6 +684,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
+            attn_backend_override=attn_backend_override,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -747,6 +756,7 @@ class KeyeSiglipVisionModel(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()

@@ -754,6 +764,7 @@ class KeyeSiglipVisionModel(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.vision_model",
+            attn_backend_override=attn_backend_override,
         )
         self.quant_config = quant_config

@@ -1296,10 +1307,16 @@ class BaseKeyeModule(nn.Module):
         self.config = config
         self.multimodal_config = multimodal_config

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual = KeyeSiglipVisionModel(
             config.vision_config,
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            attn_backend_override=attn_backend_override,
         )

         self.mlp_AR = self._build_projector(
@@ -10,6 +10,7 @@ import torch
 import torch.nn as nn
 from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

+from vllm.attention.backends.registry import _Backend
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.model_executor.layers.linear import ReplicatedLinear
@@ -105,6 +106,7 @@ class VisualTokenizer(torch.nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -113,6 +115,7 @@ class VisualTokenizer(torch.nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.vit",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         # reserved tokens for INDICATOR_IDS
         head_dim = visual_vocab_size - len(INDICATOR_IDS)
@@ -132,6 +135,7 @@ class VisualTokenizer(torch.nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         model_type = config.model_type
         if model_type == "siglip2_navit":
@@ -140,6 +144,7 @@ class VisualTokenizer(torch.nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
                 use_data_parallel=use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )
         raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")

@@ -457,6 +462,7 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config

         self.config: PretrainedConfig = config
         self.llm = init_vllm_registered_model(
@@ -464,11 +470,17 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
             prefix=maybe_prefix(prefix, "llm"),
         )

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual_tokenizer = VisualTokenizer(
             config=config.vit_config,
             visual_vocab_size=config.visual_vocab_size,
             quant_config=quant_config,
             prefix=f"{prefix}.visual_tokenizer",
+            attn_backend_override=attn_backend_override,
         )

         self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
@@ -637,6 +637,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -669,7 +670,9 @@ class Qwen2_5_VisionTransformer(nn.Module):

         use_upstream_fa = False
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if (
             self.attn_backend != _Backend.FLASH_ATTN
@@ -1226,12 +1229,18 @@ class Qwen2_5_VLForConditionalGeneration(
         if multimodal_config.get_limit_per_prompt(
             "image"
         ) or multimodal_config.get_limit_per_prompt("video"):
+            attn_backend_override = (
+                multimodal_config.mm_encoder_attn_backend
+                if multimodal_config is not None
+                else None
+            )
             self.visual = Qwen2_5_VisionTransformer(
                 config.vision_config,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "visual"),
                 use_data_parallel=self.use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )
         else:
             self.visual = None
@@ -320,6 +320,7 @@ class Qwen2VisionAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -355,6 +356,7 @@ class Qwen2VisionAttention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -497,6 +499,7 @@ class Qwen2VisionBlock(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -512,6 +515,7 @@ class Qwen2VisionBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.mlp = Qwen2VisionMLP(
             dim,
@@ -662,6 +666,7 @@ class Qwen2VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -703,6 +708,7 @@ class Qwen2VisionTransformer(nn.Module):
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]
@@ -716,7 +722,9 @@ class Qwen2VisionTransformer(nn.Module):
             use_data_parallel=use_data_parallel,
         )
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -1356,12 +1364,18 @@ class Qwen2VLForConditionalGeneration(
         if multimodal_config.get_limit_per_prompt(
             "image"
         ) or multimodal_config.get_limit_per_prompt("video"):
+            attn_backend_override = (
+                multimodal_config.mm_encoder_attn_backend
+                if multimodal_config is not None
+                else None
+            )
             self.visual = Qwen2VisionTransformer(
                 config.vision_config,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=quant_config,
                 prefix=maybe_prefix(prefix, "visual"),
                 use_data_parallel=self.use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )
         else:
             self.visual = None
@@ -296,6 +296,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         norm_eps: float = 1e-6,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.hidden_size = vision_config.hidden_size
@@ -367,7 +368,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         )

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -1144,11 +1147,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(

         self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual = Qwen3Omni_VisionTransformer(
             vision_config=thinker_config.vision_config,
             norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            attn_backend_override=attn_backend_override,
         )
         self.quant_config = quant_config

@@ -300,6 +300,7 @@ class Qwen3_VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.hidden_size = vision_config.hidden_size
@@ -359,7 +360,9 @@ class Qwen3_VisionTransformer(nn.Module):
         )

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         use_upstream_fa = False
         if (
@@ -379,7 +382,6 @@ class Qwen3_VisionTransformer(nn.Module):
             raise RuntimeError(
                 f"Qwen3-VL does not support {self.attn_backend} backend now."
             )
-
         self.blocks = nn.ModuleList(
             [
                 Qwen3_VisionBlock(
@@ -1214,12 +1216,18 @@ class Qwen3VLForConditionalGeneration(
         ) and not multimodal_config.get_limit_per_prompt("video"):
             self.visual = None
         else:
+            attn_backend_override = (
+                multimodal_config.mm_encoder_attn_backend
+                if multimodal_config is not None
+                else None
+            )
             self.visual = Qwen3_VisionTransformer(
                 config.vision_config,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=quant_config,
                 prefix=maybe_prefix(prefix, "visual"),
                 use_data_parallel=self.use_data_parallel,
+                attn_backend_override=attn_backend_override,
            )

         self.language_model = Qwen3LLMForCausalLM(
@@ -208,6 +208,7 @@ class Siglip2Attention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -248,7 +249,9 @@ class Siglip2Attention(nn.Module):

         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
-            head_size=self.head_dim, dtype=torch.get_default_dtype()
+            head_size=self.head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -372,6 +375,7 @@ class Siglip2EncoderLayer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -381,6 +385,7 @@ class Siglip2EncoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = Siglip2MLP(
@@ -434,6 +439,7 @@ class Siglip2Encoder(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -444,6 +450,7 @@ class Siglip2Encoder(nn.Module):
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{idx}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for idx in range(config.num_hidden_layers)
             ]
@@ -618,6 +625,7 @@ class Siglip2VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -629,6 +637,7 @@ class Siglip2VisionTransformer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -657,6 +666,7 @@ class Siglip2NavitModel(torch.nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()

@@ -665,6 +675,7 @@ class Siglip2NavitModel(torch.nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.vision_model",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )

     def forward(
@@ -78,10 +78,18 @@ def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInfo
     raise NotImplementedError(msg)


-def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
+def get_vit_attn_backend(
+    head_size: int,
+    dtype: torch.dtype,
+    *,
+    attn_backend_override: _Backend | None = None,
+) -> _Backend:
     """
     Get the available attention backend for Vision Transformer.
     """
+    if attn_backend_override is not None:
+        return attn_backend_override
+
     # Lazy import to avoid circular dependency
     from vllm.attention.selector import get_env_variable_attn_backend

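With the keyword-only override in place, `get_vit_attn_backend()` returns the requested backend before any environment-variable or platform probing. A minimal sketch of that short-circuit, assuming the function is importable from `vllm.model_executor.models.vision` (the file name is not shown in this diff):

import torch

from vllm.attention.backends.registry import _Backend
from vllm.model_executor.models.vision import get_vit_attn_backend  # assumed path

override = _Backend.TORCH_SDPA
selected = get_vit_attn_backend(
    head_size=64,
    dtype=torch.get_default_dtype(),
    attn_backend_override=override,
)
assert selected is override  # no auto-detection when an override is supplied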