[MM][Core] Decouple ViT backend from LM backend (#27061)

Signed-off-by: Roger Wang <hey@rogerw.io>

parent 72f431e709, commit c3a2c6ac5f
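
For context (not part of the diff): the override is exposed as the mm_encoder_attn_backend engine argument and the --mm-encoder-attn-backend CLI flag added below, while the language-model attention backend keeps being resolved by the existing selector. A minimal usage sketch, assuming the offline LLM entrypoint forwards extra keyword arguments to EngineArgs as it does for the other multimodal options; the model name is only an example:

# Hedged sketch: the model name and kwarg forwarding are assumptions, not part of this commit.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # example multimodal model
    # Pin FlashAttention for the ViT encoder only; the LM attention backend
    # is still chosen independently by the usual selection logic.
    mm_encoder_attn_backend="FLASH_ATTN",
)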
tests/config/test_multimodal_config.py (new file, +25 lines)

@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.attention.backends.registry import _Backend
+from vllm.config.multimodal import MultiModalConfig
+
+
+def test_mm_encoder_attn_backend_str_conversion():
+    config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
+    assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN
+
+
+def test_mm_encoder_attn_backend_invalid():
+    with pytest.raises(ValueError):
+        MultiModalConfig(mm_encoder_attn_backend="not_a_backend")
+
+
+def test_mm_encoder_attn_backend_hash_updates():
+    base_hash = MultiModalConfig().compute_hash()
+    overridden_hash = MultiModalConfig(
+        mm_encoder_attn_backend=_Backend.FLASH_ATTN
+    ).compute_hash()
+    assert base_hash != overridden_hash
@@ -16,6 +16,7 @@ from vllm.attention.backends.registry import _Backend, backend_name_to_enum
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config.multimodal import MultiModalConfig
 from vllm.config.vllm import VllmConfig
 from vllm.distributed.kv_transfer import (
     get_kv_transfer_group,

@@ -443,6 +444,7 @@ class MultiHeadAttention(nn.Module):
         # This has no effect, it is only here to make it easier to swap
         # between Attention and MultiHeadAttention
         prefix: str = "",
+        multimodal_config: MultiModalConfig | None = None,
     ) -> None:
         super().__init__()
         self.num_heads = num_heads

@@ -462,7 +464,14 @@ class MultiHeadAttention(nn.Module):
         dtype = torch.get_default_dtype()

         # Determine the attention backend
-        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
+        attn_backend_override = None
+        if multimodal_config is not None:
+            attn_backend_override = multimodal_config.mm_encoder_attn_backend
+        backend = get_vit_attn_backend(
+            head_size=head_size,
+            dtype=dtype,
+            attn_backend_override=attn_backend_override,
+        )

         # Some auto-selected backends can be upgraded
         # to upstream flash attention if available.
@@ -50,6 +50,7 @@ if TYPE_CHECKING:

     import vllm.model_executor.layers.quantization as me_quant
     import vllm.model_executor.models as me_models
+    from vllm.attention.backends.registry import _Backend
     from vllm.config.load import LoadConfig
     from vllm.config.parallel import ParallelConfig
     from vllm.model_executor.layers.quantization import QuantizationMethods

@@ -57,6 +58,7 @@ if TYPE_CHECKING:
 else:
     PretrainedConfig = Any

+    _Backend = Any
     me_quant = LazyLoader(
         "model_executor", globals(), "vllm.model_executor.layers.quantization"
     )

@@ -307,6 +309,7 @@ class ModelConfig:
     mm_processor_cache_type: InitVar[MMCacheType | None] = None
     mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
     mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
+    mm_encoder_attn_backend: InitVar[_Backend | str | None] = None
     interleave_mm_strings: InitVar[bool | None] = None
     skip_mm_profiling: InitVar[bool | None] = None
     video_pruning_rate: InitVar[float | None] = None

@@ -424,6 +427,7 @@ class ModelConfig:
         mm_processor_cache_type: MMCacheType | None,
         mm_shm_cache_max_object_size_mb: int | None,
         mm_encoder_tp_mode: MMEncoderTPMode | None,
+        mm_encoder_attn_backend: _Backend | str | None,
         interleave_mm_strings: bool | None,
         skip_mm_profiling: bool | None,
         video_pruning_rate: float | None,

@@ -733,6 +737,7 @@ class ModelConfig:
             mm_processor_cache_type=mm_processor_cache_type,
             mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb,
             mm_encoder_tp_mode=mm_encoder_tp_mode,
+            mm_encoder_attn_backend=mm_encoder_attn_backend,
             interleave_mm_strings=interleave_mm_strings,
             skip_mm_profiling=skip_mm_profiling,
             video_pruning_rate=video_pruning_rate,
@@ -3,13 +3,18 @@

 import hashlib
 from collections.abc import Mapping
-from typing import Any, Literal, TypeAlias
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias

 from pydantic import ConfigDict, Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass

 from vllm.config.utils import config

+if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
+else:
+    _Backend = Any
+

 @dataclass
 class BaseDummyOptions:

@@ -112,6 +117,10 @@ class MultiModalConfig:
     DP (which is controlled by `--data-parallel-size`).
     This is only supported on a per-model basis and falls back to
     `"weights"` if the encoder does not support DP."""
+    mm_encoder_attn_backend: _Backend | None = None
+    """Optional override for the multi-modal encoder attention backend when
+    using vision transformers. Accepts any value from
+    `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
     interleave_mm_strings: bool = False
     """Enable fully interleaved support for multimodal prompts, while using
     --chat-template-content-format=string."""

@@ -148,6 +157,29 @@ class MultiModalConfig:
                 value[k] = BaseDummyOptions(**v)
         return value

+    @field_validator("mm_encoder_attn_backend", mode="before")
+    @classmethod
+    def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
+        from vllm.attention.backends.registry import (
+            _Backend as BackendEnum,
+        )
+        from vllm.attention.backends.registry import (
+            backend_name_to_enum,
+        )
+
+        if value is None or isinstance(value, BackendEnum):
+            return value
+
+        if isinstance(value, str):
+            candidate = backend_name_to_enum(value.upper())
+            if candidate is not None:
+                return candidate
+
+        valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
+        raise ValueError(
+            f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
+        )
+
     @model_validator(mode="after")
     def _validate_multimodal_config(self):
         if self.mm_processor_cache_type != "shm" and (

@@ -172,9 +204,11 @@ class MultiModalConfig:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        # no factors to consider.
-        # this config will not affect the computation graph.
-        factors: list[Any] = []
+        factors: list[Any] = [
+            self.mm_encoder_attn_backend.name
+            if self.mm_encoder_attn_backend is not None
+            else None
+        ]
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
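
Taken together, the validator and the compute_hash change above imply the following behavior (an illustrative sketch, not part of the diff): string values are upper-cased before the enum lookup, so lower-case spellings are accepted, and the chosen encoder backend now participates in the config hash:

# Illustrative sketch of the validator and hash behavior added above.
from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig

cfg = MultiModalConfig(mm_encoder_attn_backend="flash_attn")  # normalized via .upper()
assert cfg.mm_encoder_attn_backend == _Backend.FLASH_ATTN

# Because the backend name is folded into compute_hash(), caches keyed on the
# multimodal config hash are invalidated when the encoder backend changes.
assert MultiModalConfig().compute_hash() != cfg.compute_hash()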
@@ -32,6 +32,7 @@ from pydantic.fields import FieldInfo
 from typing_extensions import TypeIs, deprecated

 import vllm.envs as envs
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (
     CacheConfig,
     CompilationConfig,

@@ -451,6 +452,9 @@ class EngineArgs:
         MultiModalConfig.mm_shm_cache_max_object_size_mb
     )
     mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
+    mm_encoder_attn_backend: _Backend | str | None = (
+        MultiModalConfig.mm_encoder_attn_backend
+    )
     io_processor_plugin: str | None = None
     skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
     video_pruning_rate: float = MultiModalConfig.video_pruning_rate

@@ -914,6 +918,10 @@ class EngineArgs:
         multimodal_group.add_argument(
             "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
         )
+        multimodal_group.add_argument(
+            "--mm-encoder-attn-backend",
+            **multimodal_kwargs["mm_encoder_attn_backend"],
+        )
         multimodal_group.add_argument(
             "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
         )

@@ -1160,6 +1168,7 @@ class EngineArgs:
             mm_processor_cache_type=self.mm_processor_cache_type,
             mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
             mm_encoder_tp_mode=self.mm_encoder_tp_mode,
+            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
             pooler_config=self.pooler_config,
             override_pooler_config=self.override_pooler_config,
             logits_processor_pattern=self.logits_processor_pattern,
@@ -256,6 +256,7 @@ class DotsVisionAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -288,7 +289,9 @@ class DotsVisionAttention(nn.Module):
         )
         # Select attention backend
         self.attn_backend = get_vit_attn_backend(
-            self.hidden_size_per_attention_head, torch.get_default_dtype()
+            self.hidden_size_per_attention_head,
+            torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -510,6 +513,7 @@ class DotsVisionBlock(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()

@@ -521,6 +525,7 @@ class DotsVisionBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
         self.mlp = DotsSwiGLUFFN(

@@ -561,6 +566,7 @@ class DotsVisionTransformer(nn.Module):
         require_post_norm: bool | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.config = config

@@ -571,7 +577,9 @@ class DotsVisionTransformer(nn.Module):
         head_dim = config.embed_dim // config.num_attention_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()

@@ -591,6 +599,7 @@ class DotsVisionTransformer(nn.Module):
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{i}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for i in range(num_layers)
             ]

@@ -750,11 +759,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
             self.config.vision_config = vision_config
         else:
             vision_config = self.config.vision_config
+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.vision_tower = DotsVisionTransformer(
             vision_config,
             quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "vision_tower"),
             use_data_parallel=self.use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -164,6 +164,7 @@ class Ernie4_5_VisionAttention(nn.Module):
         projection_size: int,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.

@@ -196,6 +197,7 @@ class Ernie4_5_VisionAttention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )

         self.use_upstream_fa = False

@@ -367,6 +369,7 @@ class Ernie4_5_VisionBlock(nn.Module):
         norm_layer: Callable[[int], nn.Module] | None = None,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -382,6 +385,7 @@ class Ernie4_5_VisionBlock(nn.Module):
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            attn_backend_override=attn_backend_override,
         )

         self.mlp = Ernie4_5_VisionMLP(

@@ -458,6 +462,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
         norm_eps: float = 1e-6,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         patch_size = vision_config.patch_size

@@ -493,6 +498,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
                     norm_layer=norm_layer,
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]

@@ -504,7 +510,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
         self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()

@@ -1327,11 +1335,17 @@ class Ernie4_5_VLMoeForConditionalGeneration(
         self.config = config
         self.multimodal_config = multimodal_config

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.vision_model = Ernie4_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "vision_model"),
+            attn_backend_override=attn_backend_override,
         )

         self.language_model = Ernie4_5_VLMoeForCausalLM(
@@ -247,6 +247,7 @@ class Glm4vVisionAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.

@@ -287,6 +288,7 @@ class Glm4vVisionAttention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -417,6 +419,7 @@ class Glm4vVisionBlock(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:

@@ -430,6 +433,7 @@ class Glm4vVisionBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.mlp = Glm4vVisionMLP(
             dim,

@@ -696,6 +700,7 @@ class Glm4vVisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -731,6 +736,7 @@ class Glm4vVisionTransformer(nn.Module):
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     use_data_parallel=self.use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]

@@ -759,7 +765,9 @@ class Glm4vVisionTransformer(nn.Module):
         )

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()

@@ -1437,12 +1445,18 @@ class Glm4vForConditionalGeneration(
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
             use_data_parallel=self.use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )

         if config.model_type == "glm4v":
@@ -353,6 +353,7 @@ class KeyeSiglipAttention(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config

@@ -392,7 +393,9 @@ class KeyeSiglipAttention(nn.Module):

         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
-            head_size=self.head_dim, dtype=torch.get_default_dtype()
+            head_size=self.head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )

         self.use_upstream_fa = False

@@ -521,6 +524,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -529,6 +533,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            attn_backend_override=attn_backend_override,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(

@@ -573,6 +578,7 @@ class KeyeSiglipEncoder(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config

@@ -585,6 +591,7 @@ class KeyeSiglipEncoder(nn.Module):
                     config,
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{layer_idx}",
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(config.num_hidden_layers)
             ]

@@ -666,6 +673,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config

@@ -676,6 +684,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
+            attn_backend_override=attn_backend_override,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -747,6 +756,7 @@ class KeyeSiglipVisionModel(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()

@@ -754,6 +764,7 @@ class KeyeSiglipVisionModel(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.vision_model",
+            attn_backend_override=attn_backend_override,
         )
         self.quant_config = quant_config

@@ -1296,10 +1307,16 @@ class BaseKeyeModule(nn.Module):
         self.config = config
         self.multimodal_config = multimodal_config

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual = KeyeSiglipVisionModel(
             config.vision_config,
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            attn_backend_override=attn_backend_override,
         )

         self.mlp_AR = self._build_projector(
@@ -10,6 +10,7 @@ import torch
 import torch.nn as nn
 from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

+from vllm.attention.backends.registry import _Backend
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.model_executor.layers.linear import ReplicatedLinear

@@ -105,6 +106,7 @@ class VisualTokenizer(torch.nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config

@@ -113,6 +115,7 @@ class VisualTokenizer(torch.nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.vit",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         # reserved tokens for INDICATOR_IDS
         head_dim = visual_vocab_size - len(INDICATOR_IDS)

@@ -132,6 +135,7 @@ class VisualTokenizer(torch.nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         model_type = config.model_type
         if model_type == "siglip2_navit":

@@ -140,6 +144,7 @@ class VisualTokenizer(torch.nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
                 use_data_parallel=use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )
         raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")

@@ -457,6 +462,7 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config

         self.config: PretrainedConfig = config
         self.llm = init_vllm_registered_model(

@@ -464,11 +470,17 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
             prefix=maybe_prefix(prefix, "llm"),
         )

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual_tokenizer = VisualTokenizer(
             config=config.vit_config,
             visual_vocab_size=config.visual_vocab_size,
             quant_config=quant_config,
             prefix=f"{prefix}.visual_tokenizer",
+            attn_backend_override=attn_backend_override,
         )

         self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
@@ -637,6 +637,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -669,7 +670,9 @@ class Qwen2_5_VisionTransformer(nn.Module):

         use_upstream_fa = False
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if (
             self.attn_backend != _Backend.FLASH_ATTN

@@ -1226,12 +1229,18 @@ class Qwen2_5_VLForConditionalGeneration(
         if multimodal_config.get_limit_per_prompt(
             "image"
         ) or multimodal_config.get_limit_per_prompt("video"):
+            attn_backend_override = (
+                multimodal_config.mm_encoder_attn_backend
+                if multimodal_config is not None
+                else None
+            )
             self.visual = Qwen2_5_VisionTransformer(
                 config.vision_config,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "visual"),
                 use_data_parallel=self.use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )
         else:
             self.visual = None
@@ -320,6 +320,7 @@ class Qwen2VisionAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.

@@ -355,6 +356,7 @@ class Qwen2VisionAttention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -497,6 +499,7 @@ class Qwen2VisionBlock(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:

@@ -512,6 +515,7 @@ class Qwen2VisionBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.mlp = Qwen2VisionMLP(
             dim,

@@ -662,6 +666,7 @@ class Qwen2VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -703,6 +708,7 @@ class Qwen2VisionTransformer(nn.Module):
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]

@@ -716,7 +722,9 @@ class Qwen2VisionTransformer(nn.Module):
             use_data_parallel=use_data_parallel,
         )
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()

@@ -1356,12 +1364,18 @@ class Qwen2VLForConditionalGeneration(
         if multimodal_config.get_limit_per_prompt(
             "image"
         ) or multimodal_config.get_limit_per_prompt("video"):
+            attn_backend_override = (
+                multimodal_config.mm_encoder_attn_backend
+                if multimodal_config is not None
+                else None
+            )
             self.visual = Qwen2VisionTransformer(
                 config.vision_config,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=quant_config,
                 prefix=maybe_prefix(prefix, "visual"),
                 use_data_parallel=self.use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )
         else:
             self.visual = None
@@ -296,6 +296,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         norm_eps: float = 1e-6,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.hidden_size = vision_config.hidden_size

@@ -367,7 +368,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         )

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()

@@ -1144,11 +1147,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(

         self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual = Qwen3Omni_VisionTransformer(
             vision_config=thinker_config.vision_config,
             norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            attn_backend_override=attn_backend_override,
         )
         self.quant_config = quant_config
@@ -300,6 +300,7 @@ class Qwen3_VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.hidden_size = vision_config.hidden_size

@@ -359,7 +360,9 @@ class Qwen3_VisionTransformer(nn.Module):
         )

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         use_upstream_fa = False
         if (

@@ -379,7 +382,6 @@ class Qwen3_VisionTransformer(nn.Module):
             raise RuntimeError(
                 f"Qwen3-VL does not support {self.attn_backend} backend now."
             )
-
         self.blocks = nn.ModuleList(
             [
                 Qwen3_VisionBlock(

@@ -1214,12 +1216,18 @@ class Qwen3VLForConditionalGeneration(
         ) and not multimodal_config.get_limit_per_prompt("video"):
             self.visual = None
         else:
+            attn_backend_override = (
+                multimodal_config.mm_encoder_attn_backend
+                if multimodal_config is not None
+                else None
+            )
             self.visual = Qwen3_VisionTransformer(
                 config.vision_config,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=quant_config,
                 prefix=maybe_prefix(prefix, "visual"),
                 use_data_parallel=self.use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )

             self.language_model = Qwen3LLMForCausalLM(
@@ -208,6 +208,7 @@ class Siglip2Attention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config

@@ -248,7 +249,9 @@ class Siglip2Attention(nn.Module):

         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
-            head_size=self.head_dim, dtype=torch.get_default_dtype()
+            head_size=self.head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -372,6 +375,7 @@ class Siglip2EncoderLayer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size

@@ -381,6 +385,7 @@ class Siglip2EncoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = Siglip2MLP(

@@ -434,6 +439,7 @@ class Siglip2Encoder(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config

@@ -444,6 +450,7 @@ class Siglip2Encoder(nn.Module):
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{idx}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for idx in range(config.num_hidden_layers)
             ]

@@ -618,6 +625,7 @@ class Siglip2VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config

@@ -629,6 +637,7 @@ class Siglip2VisionTransformer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -657,6 +666,7 @@ class Siglip2NavitModel(torch.nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()

@@ -665,6 +675,7 @@ class Siglip2NavitModel(torch.nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.vision_model",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
        )

     def forward(
|
|||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
|
||||||
def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
|
def get_vit_attn_backend(
|
||||||
|
head_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
*,
|
||||||
|
attn_backend_override: _Backend | None = None,
|
||||||
|
) -> _Backend:
|
||||||
"""
|
"""
|
||||||
Get the available attention backend for Vision Transformer.
|
Get the available attention backend for Vision Transformer.
|
||||||
"""
|
"""
|
||||||
|
if attn_backend_override is not None:
|
||||||
|
return attn_backend_override
|
||||||
|
|
||||||
# Lazy import to avoid circular dependency
|
# Lazy import to avoid circular dependency
|
||||||
from vllm.attention.selector import get_env_variable_attn_backend
|
from vllm.attention.selector import get_env_variable_attn_backend
|
||||||
|
|
||||||
|
|||||||
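
The new keyword-only parameter short-circuits backend detection: an explicit override wins unconditionally, while passing nothing keeps the previous env-var/auto selection. An illustrative sketch, assuming the helper lives in vllm.model_executor.models.vision as the surrounding get_vision_encoder_info hunk suggests; head_size=64 and float16 are arbitrary example values:

# Hedged sketch; the import path and argument values are assumptions.
import torch

from vllm.attention.backends.registry import _Backend
from vllm.model_executor.models.vision import get_vit_attn_backend

# An explicit override (e.g. from --mm-encoder-attn-backend) is returned as-is.
backend = get_vit_attn_backend(
    head_size=64,
    dtype=torch.float16,
    attn_backend_override=_Backend.FLASH_ATTN,
)
assert backend == _Backend.FLASH_ATTN

# Without an override, selection falls through to the existing logic.
auto_backend = get_vit_attn_backend(head_size=64, dtype=torch.float16)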