[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. (#30125)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
Shanshan Shen 2025-12-15 11:13:32 +08:00 committed by GitHub
parent 84e23d103d
commit 87b4d1557d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 1262 additions and 851 deletions

View File

@ -0,0 +1,434 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Consolidated test for ViT attention backend functionality across multiple models.
This test validates that each multimodal model can successfully generate outputs
using different ViT attention backends. Tests are parametrized by model and backend.
"""
from dataclasses import asdict
from typing import Any
import pytest
from transformers import AutoProcessor
from vllm import LLM, EngineArgs, SamplingParams
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.video import sample_frames_from_video
from vllm.platforms import current_platform
from ....utils import create_new_process_for_each_test
from ...utils import dummy_hf_overrides
# Dots.OCR prompt from official repository
# https://github.com/rednote-hilab/dots.ocr/blob/d72d1d8c5bdd0362eb264f714cdbd1e5daa7cdff/dots_ocr/utils/prompts.py#L3
# ruff: noqa: E501
DOTS_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object.
"""
VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
# Model configurations
MODEL_CONFIGS: dict[str, dict[str, Any]] = {
"dots_ocr": {
"model_name": "rednote-hilab/dots.ocr",
"interface": "llm_chat",
"max_model_len": 32768,
"max_num_seqs": 1,
"limit_mm_per_prompt": {"image": 1},
"sampling_params": {
"temperature": 0.1,
"max_tokens": 16384,
"top_p": 0.9,
"stop_token_ids": None,
},
"use_specific_image": "stop_sign",
"prompt_builder": "build_dots_ocr_prompt",
"output_validator": lambda x: len(x) > 10 and "stop" in x.lower(),
},
"ernie45_vl": {
"model_name": "baidu/ERNIE-4.5-VL-28B-A3B-PT",
"interface": "llm_generate",
"max_model_len": 16384,
"max_num_seqs": 2,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"use_processor": True,
"question": "What is the content of each image?",
},
"glm4_1v": {
"model_name": "zai-org/GLM-4.1V-9B-Thinking",
"interface": "llm_generate",
"max_model_len": 32768,
"max_num_seqs": 2,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"use_processor": True,
"question": "What is the content of each image?",
},
"keye_vl": {
"model_name": "Kwai-Keye/Keye-VL-8B-Preview",
"interface": "llm_generate",
"max_model_len": 8192,
"max_num_seqs": 5,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"supported_backends": {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
},
"use_processor": True,
"question": "What is the content of each image?",
},
"ovis2_5": {
"model_name": "AIDC-AI/Ovis2.5-2B",
"interface": "llm_generate",
"max_model_len": 8192,
"max_num_seqs": 2,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"prompt_builder": "build_ovis_prompt",
"question": "What is the content of each image?",
},
"qwen2_5_vl": {
"model_name": "Qwen/Qwen2.5-VL-3B-Instruct",
"interface": "vllm_runner",
"media_type": "video",
"max_model_len": 4000,
"max_num_seqs": 1,
"limit_mm_per_prompt": {"video": 1},
"sampling_params": {
"max_tokens": 128,
},
"runner_kwargs": {
"runner": "generate",
"dtype": "bfloat16",
},
"video_params": {
"num_frames": 16,
"pruning_rates": [0.0, 0.75],
},
},
"qwen2_5_omni": {
"model_name": "Qwen/Qwen2.5-Omni-3B",
"interface": "llm_generate",
"max_model_len": 32768,
"max_num_seqs": 2,
"limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3},
"sampling_params": {
"temperature": 0.6,
"top_p": 0.95,
"top_k": 20,
"max_tokens": 16384,
},
"use_processor": True,
"question": "What is the content of each image?",
},
"qwen3_omni": {
"model_name": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
"interface": "llm_generate",
"max_model_len": 32768,
"max_num_seqs": 2,
"limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3},
"sampling_params": {
"temperature": 0.6,
"top_p": 0.95,
"top_k": 20,
"max_tokens": 16384,
},
"use_processor": True,
"question": "What is the content of each image?",
},
}
# Prompt builder functions
def build_dots_ocr_prompt(images, config):
"""Build Dots.OCR specific prompt with OCR instructions."""
# Use only stop_sign image for Dots.OCR
image = images[0] # Already filtered to stop_sign
image_url = f"data:image/jpeg;base64,{encode_image_base64(image)}"
placeholders = [{"type": "image_url", "image_url": {"url": image_url}}]
messages = [
{
"role": "user",
"content": [
*placeholders,
{
"type": "text",
"text": f"<|img|><|imgpad|><|endofimg|>{DOTS_OCR_PROMPT}",
},
],
},
]
return messages
def build_processor_prompt(images, config):
"""Build prompt using AutoProcessor.apply_chat_template()."""
processor = AutoProcessor.from_pretrained(
config["model_name"], trust_remote_code=True
)
image_urls = [
f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
]
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": config["question"]},
],
},
]
return processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
def build_ovis_prompt(images, config):
"""Build Ovis2.5 specific prompt with custom format."""
image_urls = [
f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
]
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
return (
f"<|im_start|>user\n\n{placeholders}\n{config['question']}<|im_end|>\n"
"<|im_start|>assistant\n"
)
def build_qwen2_5_video_prompt():
"""Build Qwen2.5-VL video prompt with EVS placeholder."""
return (
f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n{VIDEO_PLACEHOLDER}"
"Describe this video with a short sentence (no more than 20 words)"
"<|im_end|><|im_start|>assistant\n"
)
# Handler functions
def run_llm_generate_test(config, mm_encoder_attn_backend, image_assets):
"""Standard LLM.generate() interface handler."""
images = [asset.pil_image for asset in image_assets]
# Build prompt
if config.get("use_processor"):
prompt = build_processor_prompt(images, config)
else:
prompt_builder_name = config.get("prompt_builder", "build_ovis_prompt")
prompt_builder = globals()[prompt_builder_name]
prompt = prompt_builder(images, config)
# Determine limit_mm_per_prompt
limit_mm_per_prompt = config.get("limit_mm_per_prompt", {"image": len(images)})
# Create engine
engine_args = EngineArgs(
model=config["model_name"],
trust_remote_code=True,
max_model_len=config["max_model_len"],
max_num_seqs=config["max_num_seqs"],
limit_mm_per_prompt=limit_mm_per_prompt,
mm_encoder_attn_backend=mm_encoder_attn_backend,
hf_overrides=dummy_hf_overrides,
load_format="dummy",
)
engine_dict = asdict(engine_args) | {"seed": 42}
llm = LLM(**engine_dict)
# Generate
sampling_params = SamplingParams(**config["sampling_params"])
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {"image": images},
},
sampling_params=sampling_params,
)
# Validate
for o in outputs:
generated_text = o.outputs[0].text
validator = config.get("output_validator", lambda x: len(x) > 10)
assert validator(generated_text), (
f"Validation failed for {config['model_name']}: {generated_text}"
)
def run_llm_chat_test(config, mm_encoder_attn_backend, image_assets):
"""LLM.chat() interface handler for Dots.OCR."""
# Filter to stop_sign image only
stop_sign_image = [
asset.pil_image for asset in image_assets if asset.name == "stop_sign"
][0]
# Build messages
messages = build_dots_ocr_prompt([stop_sign_image], config)
# Create engine
engine_args = EngineArgs(
model=config["model_name"],
trust_remote_code=True,
max_model_len=config["max_model_len"],
max_num_seqs=config["max_num_seqs"],
limit_mm_per_prompt=config["limit_mm_per_prompt"],
mm_encoder_attn_backend=mm_encoder_attn_backend,
hf_overrides=dummy_hf_overrides,
load_format="dummy",
)
engine_dict = asdict(engine_args) | {"seed": 42}
llm = LLM(**engine_dict)
# Generate using chat
sampling_params = SamplingParams(**config["sampling_params"])
outputs = llm.chat(messages=messages, sampling_params=sampling_params)
# Validate
for o in outputs:
generated_text = o.outputs[0].text
validator = config.get("output_validator", lambda x: len(x) > 10)
assert validator(generated_text), (
f"Validation failed for {config['model_name']}: {generated_text}"
)
def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
"""Video test with EVS (Efficient Video Sampling) handler."""
for pruning_rate in config["video_params"]["pruning_rates"]:
num_frames = config["video_params"]["num_frames"]
# Sample frames from video
sampled_vids = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets
]
# Build prompt and prepare video
prompt = build_qwen2_5_video_prompt()
prompts = [prompt]
videos = [sampled_vids[0]]
# Run with vllm_runner context manager
with vllm_runner(
config["model_name"],
max_model_len=config["max_model_len"],
max_num_seqs=config["max_num_seqs"],
limit_mm_per_prompt=config["limit_mm_per_prompt"],
tensor_parallel_size=1,
video_pruning_rate=pruning_rate,
mm_encoder_attn_backend=mm_encoder_attn_backend,
hf_overrides=dummy_hf_overrides,
load_format="dummy",
**config["runner_kwargs"],
) as vllm_model:
outputs = vllm_model.generate_greedy(
prompts,
config["sampling_params"]["max_tokens"],
videos=videos,
)
# Validate output
assert len(outputs) == 1, f"Expected 1 output, got {len(outputs)}"
output_ids, output_text = outputs[0]
assert len(output_ids) > 0, "Generated no output IDs"
assert len(output_text) > 0, "Generated empty text"
assert isinstance(output_text, str), (
f"Output is not string: {type(output_text)}"
)
# Main test function
@pytest.mark.parametrize("model_key", list(MODEL_CONFIGS.keys()))
@pytest.mark.parametrize(
"mm_encoder_attn_backend",
[None] + current_platform.get_supported_vit_attn_backends(),
)
@create_new_process_for_each_test()
def test_vit_backend_functionality(
model_key: str,
mm_encoder_attn_backend: AttentionBackendEnum | None,
image_assets,
video_assets,
vllm_runner,
request,
):
"""Test ViT attention backend functionality for multimodal models.
This test validates that each model can successfully generate outputs
using different ViT attention backends. The test:
1. Filters unsupported backends per model
2. Applies appropriate GPU marks
3. Routes to the correct test handler based on interface
4. Validates output meets minimum requirements
"""
config = MODEL_CONFIGS[model_key]
# Step 1: Backend filtering
if (
"supported_backends" in config
and mm_encoder_attn_backend is not None
and mm_encoder_attn_backend not in config["supported_backends"]
):
pytest.skip(
f"{model_key} does not support {mm_encoder_attn_backend} backend now."
)
# Step 2: Apply GPU marks dynamically
if "gpu_marks" in config:
for mark in config["gpu_marks"]:
request.applymarker(mark)
# Step 3: Route to appropriate handler
if config.get("media_type") == "video":
run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner)
elif config["interface"] == "llm_chat":
run_llm_chat_test(config, mm_encoder_attn_backend, image_assets)
elif config["interface"] == "llm_generate":
run_llm_generate_test(config, mm_encoder_attn_backend, image_assets)
else:
raise ValueError(f"Unknown interface: {config['interface']}")

View File

@ -3,7 +3,6 @@
"""Attention layer."""
import functools
from collections.abc import Callable
from typing import cast
import torch
@ -17,6 +16,7 @@ from vllm.attention.backends.abstract import (
MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
@ -49,58 +49,9 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
)
if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx9
else:
on_gfx9 = lambda *args, **kwargs: False
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum,
attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = AttentionBackendEnum.ROCM_AITER_FA
elif (
attn_backend_override is None
and on_gfx9()
and attn_backend == AttentionBackendEnum.FLASH_ATTN
):
pass
else:
return AttentionBackendEnum.TORCH_SDPA, None
elif current_platform.is_cuda():
pass
elif current_platform.is_xpu():
assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend."
)
pass
else:
return AttentionBackendEnum.TORCH_SDPA, None
if attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
try:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
else:
flash_attn_varlen_func = None
return attn_backend, flash_attn_varlen_func
def _init_kv_cache_quant(
layer: nn.Module,
quant_config: QuantizationConfig | None,
@ -496,29 +447,15 @@ class MultiHeadAttention(nn.Module):
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
backend = get_vit_attn_backend(
self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)
self.attn_backend = (
backend
if backend
in {
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.PALLAS,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
}
else AttentionBackendEnum.TORCH_SDPA
)
self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
)
self.is_flash_attn_backend = self.attn_backend in {

View File

@ -0,0 +1,284 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend
logger = init_logger(__name__)
def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
# At this point,
# we already have the attn_backend,
# overriding logic is done in the platform-specific implementation.
# so we don't need to override backend here.
# Just return the attn_backend and flash_attn_varlen_func.
if attn_backend == AttentionBackendEnum.FLASH_ATTN:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
# if attn_backend is TORCH_SDPA,
# it will reach here and the flash_attn_varlen_func will be None.
return flash_attn_varlen_func
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float | None = None,
num_kv_heads: int | None = None,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
"""
Args:
num_heads: number of attention heads per partition.
head_size: hidden_size per attention head.
scale: scale factor.
num_kv_heads: number of kv heads.
prefix: This has no effect, it is only here to make it easier to
swap between Attention and MultiHeadAttention
multimodal_config: configs for multi-modal.
"""
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
# Try to get vision attention backend from multimodal_config.
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
# Get device-specific vision attention backend.
self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
)
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod
def enabled(cls) -> bool:
return True
def reshape_qkv_to_4d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 4D tensors:
(batch_size, seq_len, num_heads, head_size)
"""
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
return query, key, value
def reshape_qkv_to_3d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 3D tensors:
(batch_size * seq_len, num_heads, head_size)
"""
query = query.view(bsz * q_len, self.num_heads, self.head_size)
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=1)
value = torch.repeat_interleave(value, num_repeat, dim=1)
return query, key, value
def _forward_sdpa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
# TODO(Isotr0py): Migrate MultiHeadAttention
assert cu_seqlens is not None
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
query, key, value = self.reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len
)
output = vit_torch_sdpa_wrapper(
q=query,
k=key,
v=value,
cu_seqlens=cu_seqlens,
)
return output
def _forward_fa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.flash_attn_varlen_func is not None, (
"Flash attention function is not set."
)
# # TODO(Isotr0py): Migrate MultiHeadAttention
assert cu_seqlens is not None and max_seqlen is not None
bsz = query.shape[0]
output = vit_flash_attn_wrapper(
q=query,
k=key,
v=value,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
)
return output
def forward_native(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
raise ValueError(
f"Unsupported multi-modal encoder attention backend for CUDA: "
f"{self.attn_backend}."
)
def forward_cpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
def forward_xpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.is_flash_attn_backend, (
"XPU only supports FLASH_ATTN for vision attention."
)
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
def forward_tpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.attn_backend == AttentionBackendEnum.PALLAS, (
f"MMEncoderAttention on TPU only supports PALLAS backend, "
f"but got {self.attn_backend}."
)
if cu_seqlens is None:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out
logger.warning_once(
"PALLAS backend with cu_seqlens is not supported for ViT yet. ",
"Falling back to SDPA implementation.",
)
return self._forward_sdpa(query, key, value, cu_seqlens)

View File

@ -44,9 +44,7 @@ def flash_attn_maxseqlen_wrapper(
dropout_p=0.0,
causal=False,
)
context_layer = einops.rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
return context_layer
@ -59,8 +57,7 @@ def flash_attn_maxseqlen_wrapper_fake(
batch_size: int,
is_rocm_aiter: bool,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
return torch.empty_like(q)
direct_register_custom_op(
@ -106,7 +103,6 @@ def torch_sdpa_wrapper(
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer
@ -116,8 +112,7 @@ def torch_sdpa_wrapper_fake(
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
return torch.empty_like(q)
direct_register_custom_op(

View File

@ -5,15 +5,14 @@ from typing import Annotated, Literal, TypeAlias
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
@ -254,11 +253,15 @@ class DotsVisionAttention(nn.Module):
bias: bool = True,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.embed_dim = dim
self.tp_size = (
@ -287,31 +290,13 @@ class DotsVisionAttention(nn.Module):
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel,
)
# Select attention backend
self.attn_backend = get_vit_attn_backend(
self.hidden_size_per_attention_head,
torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Unsupported vision attention backend: {self.attn_backend}"
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward(
self,
@ -319,7 +304,7 @@ class DotsVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None = None,
*,
max_seqlen: int | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
@ -336,41 +321,13 @@ class DotsVisionAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
output = self.flash_attn_varlen_func(
q_,
k_,
v_,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = output.view(
bs,
-1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
s = int(cu_seqlens[i - 1])
e = int(cu_seqlens[i])
q_i = q[:, s:e].permute(0, 2, 1, 3)
k_i = k[:, s:e].permute(0, 2, 1, 3)
v_i = v[:, s:e].permute(0, 2, 1, 3)
out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
else:
raise RuntimeError("Unsupported attention backend")
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
# [B,S,H,D] -> [S,B,H*D] -> [S, C]
context_layer = context_layer.permute(1, 0, 2, 3).contiguous()
@ -385,14 +342,19 @@ class DotsSwiGLUFFN(nn.Module):
config,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
hidden_features = config.intermediate_size
in_features = config.embed_dim
bias = config.use_bias
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
# Referenced aimv2.py AIMv2SwiGLUFFN
self.fc13 = MergedColumnParallelLinear(
in_features,
@ -498,9 +460,8 @@ class DotsVisionBlock(nn.Module):
config,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
@ -510,16 +471,15 @@ class DotsVisionBlock(nn.Module):
num_heads=config.num_attention_heads,
bias=config.use_bias,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
self.mlp = DotsSwiGLUFFN(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
@ -546,12 +506,11 @@ class DotsVisionTransformer(nn.Module):
self,
config: DotsVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.config = config
@ -561,6 +520,11 @@ class DotsVisionTransformer(nn.Module):
head_dim = config.embed_dim // config.num_attention_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@ -578,9 +542,8 @@ class DotsVisionTransformer(nn.Module):
DotsVisionBlock(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{i}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for i in range(num_layers)
]
@ -592,6 +555,11 @@ class DotsVisionTransformer(nn.Module):
else:
self.post_trunk_norm = None
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.merger = PatchMerger(
dim=config.hidden_size,
context_dim=config.embed_dim,
@ -647,7 +615,7 @@ class DotsVisionTransformer(nn.Module):
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
def forward(
@ -733,17 +701,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
self.config.vision_config = vision_config
else:
vision_config = self.config.vision_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_tower = DotsVisionTransformer(
vision_config,
quant_config=self.quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config,

View File

@ -37,10 +37,10 @@ from einops import rearrange, repeat
from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@ -163,8 +163,8 @@ class Ernie4_5_VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@ -193,33 +193,13 @@ class Ernie4_5_VisionAttention(nn.Module):
prefix=f"{prefix}.proj",
)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Ernie45-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@ -253,14 +233,13 @@ class Ernie4_5_VisionAttention(nn.Module):
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb is not None:
@ -268,43 +247,14 @@ class Ernie4_5_VisionAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(output, "b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
@ -350,8 +300,8 @@ class Ernie4_5_VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@ -366,8 +316,8 @@ class Ernie4_5_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
attn_backend_override=attn_backend_override,
)
self.mlp = Ernie4_5_VisionMLP(
@ -383,7 +333,7 @@ class Ernie4_5_VisionBlock(nn.Module):
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
@ -441,8 +391,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
vision_config,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
patch_size = vision_config.patch_size
@ -477,8 +427,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@ -489,6 +439,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
)
self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@ -535,13 +488,13 @@ class Ernie4_5_VisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
def forward(
@ -1304,17 +1257,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_model = Ernie4_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
attn_backend_override=attn_backend_override,
)
self.language_model = Ernie4_5_VLMoeForCausalLM(

View File

@ -47,8 +47,10 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.config import VllmConfig
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
@ -191,10 +193,15 @@ class Glm4vVisionMLP(nn.Module):
hidden_features: int,
bias: bool = False,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2,
@ -248,12 +255,16 @@ class Glm4vVisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
@ -287,34 +298,12 @@ class Glm4vVisionAttention(nn.Module):
disable_tp=use_data_parallel,
)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@ -338,14 +327,13 @@ class Glm4vVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
@ -356,43 +344,14 @@ class Glm4vVisionAttention(nn.Module):
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
@ -406,9 +365,8 @@ class Glm4vVisionBlock(nn.Module):
mlp_hidden_dim: int,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@ -420,17 +378,16 @@ class Glm4vVisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.mlp = Glm4vVisionMLP(
dim,
mlp_hidden_dim,
bias=False,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(
@ -489,11 +446,16 @@ class Glm4vPatchMerger(nn.Module):
d_model: int,
context_dim: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
bias: bool = False,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = d_model
self.proj = ColumnParallelLinear(
self.hidden_size,
@ -649,19 +611,19 @@ class Glm4vVisionTransformer(nn.Module):
vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
assert multimodal_config is not None, "multimodal_config must be provided"
patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
in_channels = vision_config.in_channels
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel
self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
@ -690,9 +652,8 @@ class Glm4vVisionTransformer(nn.Module):
mlp_hidden_dim=vision_config.out_hidden_size,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@ -701,9 +662,9 @@ class Glm4vVisionTransformer(nn.Module):
d_model=vision_config.out_hidden_size,
context_dim=vision_config.intermediate_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
bias=False,
prefix=f"{prefix}.merger",
use_data_parallel=self.use_data_parallel,
)
self.embeddings = Glm4vVisionEmbeddings(vision_config)
@ -723,7 +684,7 @@ class Glm4vVisionTransformer(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
attn_backend_override=multimodal_config.mm_encoder_attn_backend,
)
@property
@ -775,13 +736,13 @@ class Glm4vVisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> int | None:
) -> torch.Tensor | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
def forward(
@ -1465,18 +1426,12 @@ class Glm4vForConditionalGeneration(
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Glm4vVisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
if config.model_type == "glm4v":

View File

@ -9,7 +9,6 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
@ -17,11 +16,10 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@ -80,7 +78,6 @@ from .utils import (
is_pp_missing_parameter,
maybe_prefix,
)
from .vision import get_vit_attn_backend
logger = init_logger(__name__)
@ -369,8 +366,8 @@ class KeyeSiglipAttention(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@ -408,34 +405,14 @@ class KeyeSiglipAttention(nn.Module):
prefix=f"{prefix}.out_proj",
)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_heads,
head_size=self.head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
num_kv_heads=self.num_kv_heads,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward(
self,
hidden_states: torch.Tensor,
@ -450,8 +427,7 @@ class KeyeSiglipAttention(nn.Module):
dim=-1,
)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
batch_size = q.shape[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
if rope_emb is None:
q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
@ -482,38 +458,14 @@ class KeyeSiglipAttention(nn.Module):
self.head_dim,
)
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=False,
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> b s (h d)")
output, _ = self.out_proj(context_layer)
return output
@ -547,8 +499,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@ -556,8 +508,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self.self_attn = KeyeSiglipAttention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@ -601,8 +553,8 @@ class KeyeSiglipEncoder(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@ -614,8 +566,8 @@ class KeyeSiglipEncoder(nn.Module):
KeyeSiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend_override=attn_backend_override,
)
for layer_idx in range(config.num_hidden_layers)
]
@ -696,8 +648,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@ -707,8 +659,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self.encoder = KeyeSiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@ -779,16 +731,16 @@ class KeyeSiglipVisionModel(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.vision_model = KeyeSiglipVisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
)
self.quant_config = quant_config
@ -1329,16 +1281,11 @@ class BaseKeyeModule(nn.Module):
self.config = config
self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
)
self.mlp_AR = self._build_projector(

View File

@ -240,18 +240,12 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
)
if multimodal_config.get_limit_per_prompt("image"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = OpenCUAVisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
else:
self.visual = None

View File

@ -10,8 +10,7 @@ import torch
import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -104,18 +103,16 @@ class VisualTokenizer(torch.nn.Module):
config: PretrainedConfig,
visual_vocab_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
self.vit = self._init_backbone(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vit",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
# reserved tokens for INDICATOR_IDS
head_dim = visual_vocab_size - len(INDICATOR_IDS)
@ -133,18 +130,16 @@ class VisualTokenizer(torch.nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
model_type = config.model_type
if model_type == "siglip2_navit":
return Siglip2NavitModel(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=prefix,
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")
@ -468,17 +463,12 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "llm"),
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual_tokenizer = VisualTokenizer(
config=config.vit_config,
visual_vocab_size=config.visual_vocab_size,
multimodal_config=multimodal_config,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
attn_backend_override=attn_backend_override,
)
self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)

View File

@ -22,7 +22,6 @@ from typing import Annotated, Literal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation
@ -32,13 +31,10 @@ from transformers.modeling_outputs import (
from transformers.utils import torch_int
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@ -578,9 +574,8 @@ class SiglipAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@ -608,18 +603,12 @@ class SiglipAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.attn_backend = attn_backend
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
seq_len, bs, _ = qkv.shape
@ -665,44 +654,16 @@ class SiglipAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
if max_seqlen is None:
raise ValueError("Flash attention backend requires max_seqlen.")
context_layer = vit_flash_attn_wrapper(
q,
k,
v,
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(tensor, "b s h d -> b h s d")
for tensor in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
else:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
)
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> b s (h d)")
output, _ = self.out_proj(context_layer)
output = rearrange(output, "s b d -> b s d")
return output
@ -774,10 +735,8 @@ class SiglipEncoderLayer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
*,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@ -787,9 +746,8 @@ class SiglipEncoderLayer(nn.Module):
num_heads=config.num_attention_heads,
projection_size=config.hidden_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@ -832,14 +790,18 @@ class SiglipEncoder(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
embed_dim = config.hidden_size
num_heads = config.num_attention_heads
head_dim = embed_dim // num_heads
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@ -858,9 +820,8 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(config.num_hidden_layers)
]
@ -941,8 +902,8 @@ class SiglipVisionTransformer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@ -952,8 +913,8 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@ -991,16 +952,16 @@ class SiglipVisionModel(nn.Module):
self,
config,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.vision_model = SiglipVisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
)
self.quant_config = quant_config
@ -1119,17 +1080,11 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.config = config
self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = SiglipVisionModel(
config=config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
)
self.mlp_AR = Projector(config, config.vision_config)

View File

@ -845,6 +845,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
else:
self.visual = None

View File

@ -42,13 +42,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
@ -267,10 +263,15 @@ class Qwen2_5_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
@ -304,13 +305,16 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1
if use_data_parallel
@ -342,18 +346,12 @@ class Qwen2_5_VisionAttention(nn.Module):
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel,
)
self.attn_backend = attn_backend
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
)
def forward(
self,
@ -394,32 +392,17 @@ class Qwen2_5_VisionAttention(nn.Module):
else:
q, k, v = qkv.unbind(dim=2)
if self.is_flash_attn_backend:
context_layer = vit_flash_attn_wrapper(
q,
k,
v,
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
# Never remove the next contiguous logic
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
context_layer = vit_torch_sdpa_wrapper(
q,
k,
v,
cu_seqlens,
)
context_layer = einops.rearrange(
context_layer, "b s h d -> s b (h d)", b=batch_size
).contiguous()
output, _ = self.proj(context_layer)
return output
@ -443,10 +426,8 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@ -458,10 +439,8 @@ class Qwen2_5_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2_5_VisionMLP(
dim,
@ -469,8 +448,8 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(
@ -542,10 +521,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
@ -586,9 +570,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
vision_config: Qwen2_5_VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@ -598,7 +581,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel
self.out_hidden_size = vision_config.out_hidden_size
# args for get_window_index_thw
@ -629,19 +611,17 @@ class Qwen2_5_VisionTransformer(nn.Module):
rope_parameters={"partial_rotary_factor": 0.5},
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
@ -661,10 +641,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@ -677,8 +655,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
@property
@ -1200,18 +1178,12 @@ class Qwen2_5_VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
else:
self.visual = None

View File

@ -33,7 +33,6 @@ from typing import Annotated, Any, Literal, TypeAlias
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
@ -45,10 +44,8 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils
@ -251,10 +248,15 @@ class Qwen2VisionMLP(nn.Module):
hidden_features: int,
act_layer: type[nn.Module] = QuickGELU,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.fc1 = ColumnParallelLinear(
in_features,
hidden_features,
@ -295,12 +297,16 @@ class Qwen2VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1
if use_data_parallel
@ -329,34 +335,12 @@ class Qwen2VisionAttention(nn.Module):
disable_tp=use_data_parallel,
)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen2-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@ -398,7 +382,6 @@ class Qwen2VisionAttention(nn.Module):
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
@ -409,49 +392,15 @@ class Qwen2VisionAttention(nn.Module):
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
@ -466,9 +415,8 @@ class Qwen2VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@ -482,17 +430,16 @@ class Qwen2VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2VisionMLP(
dim,
mlp_hidden_dim,
act_layer=act_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(
@ -552,10 +499,15 @@ class Qwen2VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
@ -599,9 +551,8 @@ class Qwen2VisionTransformer(nn.Module):
vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@ -615,7 +566,11 @@ class Qwen2VisionTransformer(nn.Module):
num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio
self.use_data_parallel = use_data_parallel
self.use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.out_hidden_size = vision_config.hidden_size
self.spatial_merge_size = spatial_merge_size
@ -647,8 +602,7 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
for layer_idx in range(depth)
]
@ -659,7 +613,10 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
multimodal_config=multimodal_config,
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
@ -720,7 +677,7 @@ class Qwen2VisionTransformer(nn.Module):
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
def forward(
@ -1324,18 +1281,12 @@ class Qwen2VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
else:
self.visual = None

View File

@ -48,7 +48,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
@ -192,6 +192,7 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
@ -205,6 +206,7 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.mlp = Qwen3_VisionMLP(
@ -299,8 +301,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
vision_config,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@ -347,6 +349,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(vision_config.depth)
@ -376,6 +379,12 @@ class Qwen3Omni_VisionTransformer(nn.Module):
]
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@ -1188,17 +1197,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
self.quant_config = quant_config

View File

@ -50,7 +50,7 @@ from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
@ -169,10 +169,15 @@ class Qwen3_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.linear_fc1 = ColumnParallelLinear(
in_features,
hidden_features,
@ -206,10 +211,9 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
) -> None:
super().__init__()
if norm_layer is None:
@ -221,9 +225,8 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
)
self.mlp = Qwen3_VisionMLP(
dim,
@ -231,8 +234,8 @@ class Qwen3_VisionBlock(nn.Module):
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(
@ -264,10 +267,15 @@ class Qwen3_VisionPatchMerger(nn.Module):
spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm
@ -313,9 +321,8 @@ class Qwen3_VisionTransformer(nn.Module):
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@ -326,7 +333,6 @@ class Qwen3_VisionTransformer(nn.Module):
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.use_data_parallel = use_data_parallel
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
# NOTE: This is used for creating empty tensor for all_gather for
@ -359,8 +365,8 @@ class Qwen3_VisionTransformer(nn.Module):
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
self.deepstack_merger_list = nn.ModuleList(
@ -372,13 +378,16 @@ class Qwen3_VisionTransformer(nn.Module):
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
use_data_parallel=use_data_parallel,
)
for layer_idx in range(len(self.deepstack_visual_indexes))
]
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@ -402,9 +411,8 @@ class Qwen3_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
)
for layer_idx in range(vision_config.depth)
]
@ -1277,18 +1285,12 @@ class Qwen3VLForConditionalGeneration(
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.language_model = Qwen3LLMForCausalLM(

View File

@ -418,7 +418,6 @@ class Qwen3VLMoeForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if not multimodal_config.get_limit_per_prompt(
"image"
@ -429,8 +428,8 @@ class Qwen3VLMoeForConditionalGeneration(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
self.language_model = Qwen3MoeLLMForCausalLM(

View File

@ -13,7 +13,8 @@ from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@ -28,8 +29,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from .vision import get_vit_attn_backend
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
@ -190,7 +189,7 @@ def apply_rotary_pos_emb(
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if is_flash_attn_backend and not current_platform.is_xpu():
if is_flash_attn_backend and current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
apply_rotary_emb_func = apply_rotary_emb
@ -208,6 +207,7 @@ class Siglip2Attention(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
@ -227,20 +227,25 @@ class Siglip2Attention(nn.Module):
self.dropout = config.attention_dropout
self.is_causal = False
# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
)
self.tp_size = (
@ -249,31 +254,13 @@ class Siglip2Attention(nn.Module):
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.use_rope = config.use_rope
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_heads_per_partition,
head_size=self.head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward(
self,
hidden_states: torch.Tensor,
@ -298,46 +285,23 @@ class Siglip2Attention(nn.Module):
keys.unsqueeze(0),
cos,
sin,
self.is_flash_attn_backend,
self.attn.is_flash_attn_backend,
)
queries = queries.squeeze(0)
keys = keys.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
if self.is_flash_attn_backend:
attn_output = self.flash_attn_varlen_func(
queries,
keys,
values,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
).reshape(seq_length, -1)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
batch_size = cu_seqlens.shape[0] - 1
outputs = []
cu = cu_seqlens.tolist()
for i in range(batch_size):
start_idx = cu[i]
end_idx = cu[i + 1]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
attn_output = self.attn(
query=queries.unsqueeze(0),
key=keys.unsqueeze(0),
value=values.unsqueeze(0),
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
attn_output = attn_output.reshape(
seq_length, self.num_heads_per_partition * self.head_dim
)
# Each sequence is processed independently.
q_i = queries[start_idx:end_idx].unsqueeze(0)
k_i = keys[start_idx:end_idx].unsqueeze(0)
v_i = values[start_idx:end_idx].unsqueeze(0)
# (1, seq_len, num_heads, head_dim) ->
# (1, num_heads, seq_len, head_dim)
q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
outputs.append(output_i)
attn_output = torch.cat(outputs, dim=0)
attn_output, _ = self.out_proj(attn_output)
return attn_output
@ -347,25 +311,30 @@ class Siglip2MLP(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act)
# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -380,9 +349,8 @@ class Siglip2EncoderLayer(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@ -390,16 +358,15 @@ class Siglip2EncoderLayer(nn.Module):
self.self_attn = Siglip2Attention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(
@ -444,9 +411,8 @@ class Siglip2Encoder(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@ -455,9 +421,8 @@ class Siglip2Encoder(nn.Module):
Siglip2EncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for idx in range(config.num_hidden_layers)
]
@ -630,9 +595,8 @@ class Siglip2VisionTransformer(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@ -642,9 +606,8 @@ class Siglip2VisionTransformer(nn.Module):
self.encoder = Siglip2Encoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@ -671,18 +634,16 @@ class Siglip2NavitModel(torch.nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.vision_model = Siglip2VisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
def forward(

View File

@ -88,14 +88,17 @@ def get_vit_attn_backend(
"""
Get the available attention backend for Vision Transformer.
"""
if attn_backend_override is not None:
return attn_backend_override
attn_backend = attn_backend_override
selected_backend = get_current_vllm_config().attention_config.backend
if selected_backend is not None:
return selected_backend
if attn_backend is None:
attn_backend = selected_backend
return current_platform.get_vit_attn_backend(head_size, dtype)
return current_platform.get_vit_attn_backend(
head_size,
dtype,
backend=attn_backend,
)
def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:

View File

@ -7,7 +7,7 @@ pynvml. However, it should not initialize cuda context.
import os
from collections.abc import Callable
from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar
import torch
from typing_extensions import ParamSpec
@ -255,23 +255,6 @@ class CudaPlatformBase(Platform):
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device)
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
# Try FlashAttention first
if (cc := cls.get_device_capability()) and cc.major >= 8:
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
except ImportError:
pass
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_valid_backends(
cls,
@ -418,6 +401,41 @@ class CudaPlatformBase(Platform):
return selected_backend.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.FLASH_ATTN,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend
# Try FlashAttention first
if (cc := cls.get_device_capability()) and cc.major >= 8:
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
except ImportError:
pass
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

View File

@ -7,7 +7,7 @@ import platform
import random
import sys
from datetime import timedelta
from typing import TYPE_CHECKING, Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
import numpy as np
import torch
@ -222,12 +222,6 @@ class Platform:
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_attn_backend_cls(
cls,
@ -245,6 +239,43 @@ class Platform:
"""Get the attention backend class of a device."""
return ""
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.TORCH_SDPA,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
"""
Get the vision attention backend class of a device.
NOTE: ViT Attention should be checked and override in the platform-specific
implementation. we should not override this in any other places, like
the model_executor/models/<model_name>.py.
We check if the backend is None or not:
1. If not, check if the backend is supported by the platform.
2. If None, continue to the default selection logic.
"""
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention"
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend
logger.info_once(
f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention"
)
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_device_capability(
cls,

View File

@ -3,7 +3,7 @@
import os
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
@ -187,24 +187,6 @@ class RocmPlatform(Platform):
if not on_gfx9():
supported_quantization += ["bitsandbytes"]
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return AttentionBackendEnum.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
return AttentionBackendEnum.FLASH_ATTN
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_attn_backend_cls(
cls,
@ -322,6 +304,43 @@ class RocmPlatform(Platform):
"ROCm. Note that V0 attention backends have been removed."
)
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TORCH_SDPA,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return AttentionBackendEnum.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
return AttentionBackendEnum.FLASH_ATTN
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def set_device(cls, device: torch.device) -> None:
"""

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Optional, cast
import torch
from tpu_info import device
@ -75,6 +75,32 @@ class TpuPlatform(Platform):
logger.info("Using Pallas V1 backend.")
return AttentionBackendEnum.PALLAS.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.PALLAS,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention"
f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
)
logger.info_once(f"Using backend {backend} for vit attention.")
return backend
logger.info_once(
f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
)
return AttentionBackendEnum.PALLAS
@classmethod
def set_device(cls, device: torch.device) -> None:
"""

View File

@ -3,7 +3,7 @@
import contextlib
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
@ -77,6 +77,34 @@ class XPUPlatform(Platform):
logger.info("Using Flash Attention backend.")
return AttentionBackendEnum.FLASH_ATTN.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
# XPU only supports FLASH_ATTN for vision attention.
return [
AttentionBackendEnum.FLASH_ATTN,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: "
f"{cls.get_supported_vit_attn_backends()}."
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend
logger.info_once(
f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
)
return AttentionBackendEnum.FLASH_ATTN
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
@ -110,12 +138,6 @@ class XPUPlatform(Platform):
device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
return AttentionBackendEnum.FLASH_ATTN
@classmethod
def inference_mode(cls):
return torch.no_grad()