[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. (#30125)
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent 84e23d103d
commit 87b4d1557d
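The change wires a single `mm_encoder_attn_backend` knob from `EngineArgs`/`MultiModalConfig` through to the new `MMEncoderAttention` custom op. A minimal usage sketch, assuming a CUDA build with FlashAttention available; the model name and backend choice are illustrative, and the `EngineArgs`/`asdict` pattern mirrors the test added below:

```python
# Sketch: pin the multimodal encoder (ViT) attention backend via EngineArgs.
# Assumes a CUDA build with FlashAttention installed; model choice is illustrative.
from dataclasses import asdict

from vllm import LLM, EngineArgs
from vllm.attention.backends.registry import AttentionBackendEnum

engine_args = EngineArgs(
    model="Qwen/Qwen2.5-VL-3B-Instruct",  # illustrative multimodal model
    mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN,
)
# Downloads/loads the model; MMEncoderAttention then dispatches to
# flash_attn_varlen_func inside the vision tower.
llm = LLM(**asdict(engine_args))
```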
@@ -0,0 +1,434 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Consolidated test for ViT attention backend functionality across multiple models.

This test validates that each multimodal model can successfully generate outputs
using different ViT attention backends. Tests are parametrized by model and backend.
"""

from dataclasses import asdict
from typing import Any

import pytest
from transformers import AutoProcessor

from vllm import LLM, EngineArgs, SamplingParams
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.video import sample_frames_from_video
from vllm.platforms import current_platform

from ....utils import create_new_process_for_each_test
from ...utils import dummy_hf_overrides

# Dots.OCR prompt from official repository
# https://github.com/rednote-hilab/dots.ocr/blob/d72d1d8c5bdd0362eb264f714cdbd1e5daa7cdff/dots_ocr/utils/prompts.py#L3
# ruff: noqa: E501
DOTS_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.
"""

VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"


# Model configurations
MODEL_CONFIGS: dict[str, dict[str, Any]] = {
    "dots_ocr": {
        "model_name": "rednote-hilab/dots.ocr",
        "interface": "llm_chat",
        "max_model_len": 32768,
        "max_num_seqs": 1,
        "limit_mm_per_prompt": {"image": 1},
        "sampling_params": {
            "temperature": 0.1,
            "max_tokens": 16384,
            "top_p": 0.9,
            "stop_token_ids": None,
        },
        "use_specific_image": "stop_sign",
        "prompt_builder": "build_dots_ocr_prompt",
        "output_validator": lambda x: len(x) > 10 and "stop" in x.lower(),
    },
    "ernie45_vl": {
        "model_name": "baidu/ERNIE-4.5-VL-28B-A3B-PT",
        "interface": "llm_generate",
        "max_model_len": 16384,
        "max_num_seqs": 2,
        "sampling_params": {
            "temperature": 0.0,
            "max_tokens": 256,
            "stop_token_ids": None,
        },
        "use_processor": True,
        "question": "What is the content of each image?",
    },
    "glm4_1v": {
        "model_name": "zai-org/GLM-4.1V-9B-Thinking",
        "interface": "llm_generate",
        "max_model_len": 32768,
        "max_num_seqs": 2,
        "sampling_params": {
            "temperature": 0.0,
            "max_tokens": 256,
            "stop_token_ids": None,
        },
        "use_processor": True,
        "question": "What is the content of each image?",
    },
    "keye_vl": {
        "model_name": "Kwai-Keye/Keye-VL-8B-Preview",
        "interface": "llm_generate",
        "max_model_len": 8192,
        "max_num_seqs": 5,
        "sampling_params": {
            "temperature": 0.0,
            "max_tokens": 256,
            "stop_token_ids": None,
        },
        "supported_backends": {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        },
        "use_processor": True,
        "question": "What is the content of each image?",
    },
    "ovis2_5": {
        "model_name": "AIDC-AI/Ovis2.5-2B",
        "interface": "llm_generate",
        "max_model_len": 8192,
        "max_num_seqs": 2,
        "sampling_params": {
            "temperature": 0.0,
            "max_tokens": 256,
            "stop_token_ids": None,
        },
        "prompt_builder": "build_ovis_prompt",
        "question": "What is the content of each image?",
    },
    "qwen2_5_vl": {
        "model_name": "Qwen/Qwen2.5-VL-3B-Instruct",
        "interface": "vllm_runner",
        "media_type": "video",
        "max_model_len": 4000,
        "max_num_seqs": 1,
        "limit_mm_per_prompt": {"video": 1},
        "sampling_params": {
            "max_tokens": 128,
        },
        "runner_kwargs": {
            "runner": "generate",
            "dtype": "bfloat16",
        },
        "video_params": {
            "num_frames": 16,
            "pruning_rates": [0.0, 0.75],
        },
    },
    "qwen2_5_omni": {
        "model_name": "Qwen/Qwen2.5-Omni-3B",
        "interface": "llm_generate",
        "max_model_len": 32768,
        "max_num_seqs": 2,
        "limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3},
        "sampling_params": {
            "temperature": 0.6,
            "top_p": 0.95,
            "top_k": 20,
            "max_tokens": 16384,
        },
        "use_processor": True,
        "question": "What is the content of each image?",
    },
    "qwen3_omni": {
        "model_name": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
        "interface": "llm_generate",
        "max_model_len": 32768,
        "max_num_seqs": 2,
        "limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3},
        "sampling_params": {
            "temperature": 0.6,
            "top_p": 0.95,
            "top_k": 20,
            "max_tokens": 16384,
        },
        "use_processor": True,
        "question": "What is the content of each image?",
    },
}


# Prompt builder functions
def build_dots_ocr_prompt(images, config):
    """Build Dots.OCR specific prompt with OCR instructions."""
    # Use only stop_sign image for Dots.OCR
    image = images[0]  # Already filtered to stop_sign

    image_url = f"data:image/jpeg;base64,{encode_image_base64(image)}"

    placeholders = [{"type": "image_url", "image_url": {"url": image_url}}]
    messages = [
        {
            "role": "user",
            "content": [
                *placeholders,
                {
                    "type": "text",
                    "text": f"<|img|><|imgpad|><|endofimg|>{DOTS_OCR_PROMPT}",
                },
            ],
        },
    ]

    return messages


def build_processor_prompt(images, config):
    """Build prompt using AutoProcessor.apply_chat_template()."""
    processor = AutoProcessor.from_pretrained(
        config["model_name"], trust_remote_code=True
    )

    image_urls = [
        f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
    ]
    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [
        {
            "role": "user",
            "content": [
                *placeholders,
                {"type": "text", "text": config["question"]},
            ],
        },
    ]

    return processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def build_ovis_prompt(images, config):
    """Build Ovis2.5 specific prompt with custom format."""
    image_urls = [
        f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
    ]

    placeholders = "\n".join(
        f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
    )

    return (
        f"<|im_start|>user\n\n{placeholders}\n{config['question']}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def build_qwen2_5_video_prompt():
    """Build Qwen2.5-VL video prompt with EVS placeholder."""
    return (
        f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{VIDEO_PLACEHOLDER}"
        "Describe this video with a short sentence (no more than 20 words)"
        "<|im_end|><|im_start|>assistant\n"
    )


# Handler functions
def run_llm_generate_test(config, mm_encoder_attn_backend, image_assets):
    """Standard LLM.generate() interface handler."""
    images = [asset.pil_image for asset in image_assets]

    # Build prompt
    if config.get("use_processor"):
        prompt = build_processor_prompt(images, config)
    else:
        prompt_builder_name = config.get("prompt_builder", "build_ovis_prompt")
        prompt_builder = globals()[prompt_builder_name]
        prompt = prompt_builder(images, config)

    # Determine limit_mm_per_prompt
    limit_mm_per_prompt = config.get("limit_mm_per_prompt", {"image": len(images)})

    # Create engine
    engine_args = EngineArgs(
        model=config["model_name"],
        trust_remote_code=True,
        max_model_len=config["max_model_len"],
        max_num_seqs=config["max_num_seqs"],
        limit_mm_per_prompt=limit_mm_per_prompt,
        mm_encoder_attn_backend=mm_encoder_attn_backend,
        hf_overrides=dummy_hf_overrides,
        load_format="dummy",
    )

    engine_dict = asdict(engine_args) | {"seed": 42}
    llm = LLM(**engine_dict)

    # Generate
    sampling_params = SamplingParams(**config["sampling_params"])
    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": images},
        },
        sampling_params=sampling_params,
    )

    # Validate
    for o in outputs:
        generated_text = o.outputs[0].text
        validator = config.get("output_validator", lambda x: len(x) > 10)
        assert validator(generated_text), (
            f"Validation failed for {config['model_name']}: {generated_text}"
        )


def run_llm_chat_test(config, mm_encoder_attn_backend, image_assets):
    """LLM.chat() interface handler for Dots.OCR."""
    # Filter to stop_sign image only
    stop_sign_image = [
        asset.pil_image for asset in image_assets if asset.name == "stop_sign"
    ][0]

    # Build messages
    messages = build_dots_ocr_prompt([stop_sign_image], config)

    # Create engine
    engine_args = EngineArgs(
        model=config["model_name"],
        trust_remote_code=True,
        max_model_len=config["max_model_len"],
        max_num_seqs=config["max_num_seqs"],
        limit_mm_per_prompt=config["limit_mm_per_prompt"],
        mm_encoder_attn_backend=mm_encoder_attn_backend,
        hf_overrides=dummy_hf_overrides,
        load_format="dummy",
    )

    engine_dict = asdict(engine_args) | {"seed": 42}
    llm = LLM(**engine_dict)

    # Generate using chat
    sampling_params = SamplingParams(**config["sampling_params"])
    outputs = llm.chat(messages=messages, sampling_params=sampling_params)

    # Validate
    for o in outputs:
        generated_text = o.outputs[0].text
        validator = config.get("output_validator", lambda x: len(x) > 10)
        assert validator(generated_text), (
            f"Validation failed for {config['model_name']}: {generated_text}"
        )


def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
    """Video test with EVS (Efficient Video Sampling) handler."""
    for pruning_rate in config["video_params"]["pruning_rates"]:
        num_frames = config["video_params"]["num_frames"]

        # Sample frames from video
        sampled_vids = [
            sample_frames_from_video(asset.np_ndarrays, num_frames)
            for asset in video_assets
        ]

        # Build prompt and prepare video
        prompt = build_qwen2_5_video_prompt()
        prompts = [prompt]
        videos = [sampled_vids[0]]

        # Run with vllm_runner context manager
        with vllm_runner(
            config["model_name"],
            max_model_len=config["max_model_len"],
            max_num_seqs=config["max_num_seqs"],
            limit_mm_per_prompt=config["limit_mm_per_prompt"],
            tensor_parallel_size=1,
            video_pruning_rate=pruning_rate,
            mm_encoder_attn_backend=mm_encoder_attn_backend,
            hf_overrides=dummy_hf_overrides,
            load_format="dummy",
            **config["runner_kwargs"],
        ) as vllm_model:
            outputs = vllm_model.generate_greedy(
                prompts,
                config["sampling_params"]["max_tokens"],
                videos=videos,
            )

        # Validate output
        assert len(outputs) == 1, f"Expected 1 output, got {len(outputs)}"
        output_ids, output_text = outputs[0]
        assert len(output_ids) > 0, "Generated no output IDs"
        assert len(output_text) > 0, "Generated empty text"
        assert isinstance(output_text, str), (
            f"Output is not string: {type(output_text)}"
        )


# Main test function
@pytest.mark.parametrize("model_key", list(MODEL_CONFIGS.keys()))
@pytest.mark.parametrize(
    "mm_encoder_attn_backend",
    [None] + current_platform.get_supported_vit_attn_backends(),
)
@create_new_process_for_each_test()
def test_vit_backend_functionality(
    model_key: str,
    mm_encoder_attn_backend: AttentionBackendEnum | None,
    image_assets,
    video_assets,
    vllm_runner,
    request,
):
    """Test ViT attention backend functionality for multimodal models.

    This test validates that each model can successfully generate outputs
    using different ViT attention backends. The test:
    1. Filters unsupported backends per model
    2. Applies appropriate GPU marks
    3. Routes to the correct test handler based on interface
    4. Validates output meets minimum requirements
    """
    config = MODEL_CONFIGS[model_key]

    # Step 1: Backend filtering
    if (
        "supported_backends" in config
        and mm_encoder_attn_backend is not None
        and mm_encoder_attn_backend not in config["supported_backends"]
    ):
        pytest.skip(
            f"{model_key} does not support {mm_encoder_attn_backend} backend now."
        )

    # Step 2: Apply GPU marks dynamically
    if "gpu_marks" in config:
        for mark in config["gpu_marks"]:
            request.applymarker(mark)

    # Step 3: Route to appropriate handler
    if config.get("media_type") == "video":
        run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner)
    elif config["interface"] == "llm_chat":
        run_llm_chat_test(config, mm_encoder_attn_backend, image_assets)
    elif config["interface"] == "llm_generate":
        run_llm_generate_test(config, mm_encoder_attn_backend, image_assets)
    else:
        raise ValueError(f"Unknown interface: {config['interface']}")
@@ -3,7 +3,6 @@
"""Attention layer."""

import functools
from collections.abc import Callable
from typing import cast

import torch
@@ -17,6 +16,7 @@ from vllm.attention.backends.abstract import (
    MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
@@ -49,58 +49,9 @@ from vllm.v1.kv_cache_interface import (
    SlidingWindowSpec,
)

if current_platform.is_rocm():
    from vllm.platforms.rocm import on_gfx9
else:
    on_gfx9 = lambda *args, **kwargs: False


FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)


def maybe_get_vit_flash_attn_backend(
    attn_backend: AttentionBackendEnum,
    attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
    if current_platform.is_rocm():
        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            attn_backend = AttentionBackendEnum.ROCM_AITER_FA
        elif (
            attn_backend_override is None
            and on_gfx9()
            and attn_backend == AttentionBackendEnum.FLASH_ATTN
        ):
            pass
        else:
            return AttentionBackendEnum.TORCH_SDPA, None
    elif current_platform.is_cuda():
        pass
    elif current_platform.is_xpu():
        assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
            "XPU platform only supports FLASH_ATTN as vision attention backend."
        )
        pass
    else:
        return AttentionBackendEnum.TORCH_SDPA, None

    if attn_backend in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }:
        if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func
        else:
            try:
                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
            except ImportError:
                flash_attn_varlen_func = None
    else:
        flash_attn_varlen_func = None

    return attn_backend, flash_attn_varlen_func


def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
@@ -496,29 +447,15 @@ class MultiHeadAttention(nn.Module):
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend
        backend = get_vit_attn_backend(
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        self.attn_backend = (
            backend
            if backend
            in {
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.PALLAS,
                AttentionBackendEnum.ROCM_AITER_FA,
                AttentionBackendEnum.FLASH_ATTN,
            }
            else AttentionBackendEnum.TORCH_SDPA
        )

        self.attn_backend, self._flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
            self.attn_backend,
        )

        self.is_flash_attn_backend = self.attn_backend in {
vllm/attention/layers/mm_encoder_attention.py (new file, 284 lines)
@@ -0,0 +1,284 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable

import torch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.vit_attn_wrappers import (
    vit_flash_attn_wrapper,
    vit_torch_sdpa_wrapper,
)
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend

logger = init_logger(__name__)


def maybe_get_vit_flash_attn_backend(
    attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
    # At this point we already have the attn_backend; any overriding logic is
    # handled in the platform-specific implementation, so we don't override the
    # backend here. Just return the matching flash_attn_varlen_func.
    if attn_backend == AttentionBackendEnum.FLASH_ATTN:
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    else:
        flash_attn_varlen_func = None

    # If attn_backend is TORCH_SDPA, we reach here and
    # flash_attn_varlen_func stays None.
    return flash_attn_varlen_func


@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden_size per attention head.
            scale: scale factor.
            num_kv_heads: number of kv heads.
            prefix: This has no effect, it is only here to make it easier to
                swap between Attention and MultiHeadAttention.
            multimodal_config: configs for multi-modal.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Try to get vision attention backend from multimodal_config.
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend

        # Get device-specific vision attention backend.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
            self.attn_backend,
        )

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        return True

    def reshape_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 4D tensors:
        (batch_size, seq_len, num_heads, head_size)
        """
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        return query, key, value

    def reshape_qkv_to_3d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 3D tensors:
        (batch_size * seq_len, num_heads, head_size)
        """
        query = query.view(bsz * q_len, self.num_heads, self.head_size)
        key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=1)
            value = torch.repeat_interleave(value, num_repeat, dim=1)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # TODO(Isotr0py): Migrate MultiHeadAttention
        assert cu_seqlens is not None

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query, key, value = self.reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
        )
        return output

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.flash_attn_varlen_func is not None, (
            "Flash attention function is not set."
        )
        # TODO(Isotr0py): Migrate MultiHeadAttention
        assert cu_seqlens is not None and max_seqlen is not None

        bsz = query.shape[0]

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
        )
        return output

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for CUDA: "
                f"{self.attn_backend}."
            )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.is_flash_attn_backend, (
            "XPU only supports FLASH_ATTN for vision attention."
        )
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

    def forward_tpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.attn_backend == AttentionBackendEnum.PALLAS, (
            f"MMEncoderAttention on TPU only supports PALLAS backend, "
            f"but got {self.attn_backend}."
        )
        if cu_seqlens is None:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
            return out
        logger.warning_once(
            "PALLAS backend with cu_seqlens is not supported for ViT yet. ",
            "Falling back to SDPA implementation.",
        )
        return self._forward_sdpa(query, key, value, cu_seqlens)
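The per-model diffs that follow all reduce to the same calling convention: construct the op once in `__init__`, then pass `[B, S, H, D]` query/key/value together with `cu_seqlens` (and `max_seqlen` for the flash-attention path). A rough sketch under assumed shapes; the head count, head size, and device are illustrative, and it needs a CUDA device with a bfloat16 default dtype so backend selection resolves to FLASH_ATTN or TORCH_SDPA:

```python
# Hedged sketch of the MMEncoderAttention calling convention used by the
# vision towers below. Sizes and device are illustrative assumptions.
import torch

from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention

torch.set_default_dtype(torch.bfloat16)  # backend selection reads the default dtype
attn = MMEncoderAttention(num_heads=16, head_size=80)

# One "batch" of 1024 patch tokens split into two images of 256 and 768 patches.
q = k = v = torch.randn(1, 1024, 16, 80, device="cuda")
cu_seqlens = torch.tensor([0, 256, 1024], dtype=torch.int32, device="cuda")
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

out = attn(query=q, key=k, value=v, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
assert out.shape == q.shape  # [B, S, H, D]; callers rearrange back to [S, B, H*D]
```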
@@ -44,9 +44,7 @@ def flash_attn_maxseqlen_wrapper(
        dropout_p=0.0,
        causal=False,
    )
    context_layer = einops.rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
    return context_layer


@@ -59,8 +57,7 @@ def flash_attn_maxseqlen_wrapper_fake(
    batch_size: int,
    is_rocm_aiter: bool,
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
    return torch.empty_like(q)


direct_register_custom_op(
@@ -106,7 +103,6 @@ def torch_sdpa_wrapper(
        output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
        outputs.append(output_i)
    context_layer = torch.cat(outputs, dim=1)
    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
    return context_layer


@@ -116,8 +112,7 @@ def torch_sdpa_wrapper_fake(
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
    return torch.empty_like(q)


direct_register_custom_op(
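The wrapper hunks above change only the returned layout: the custom ops now hand back `[B, S, H, D]` (with fake implementations returning `torch.empty_like(q)` to match), and the final `s b (h d)` rearrange moves into the callers. A small self-contained check of that equivalence, with illustrative sizes:

```python
# The op now returns [B, S, H, D]; callers do the "b s h d -> s b (h d)"
# rearrange themselves. Composing the new return shape with the callers'
# rearrange reproduces the old return shape exactly.
import einops
import torch

b, s, h, d = 1, 8, 2, 4
output = torch.randn(b * s, h, d)  # varlen attention output, flattened over batch

old = einops.rearrange(output, "(b s) h d -> s b (h d)", b=b)  # previous return shape
new = einops.rearrange(output, "(b s) h d -> b s h d", b=b)    # new return shape
assert torch.equal(einops.rearrange(new, "b s h d -> s b (h d)"), old)
```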
@@ -5,15 +5,14 @@ from typing import Annotated, Literal, TypeAlias

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
    maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
    MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
@@ -254,11 +253,15 @@ class DotsVisionAttention(nn.Module):
        bias: bool = True,
        *,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )

        self.embed_dim = dim
        self.tp_size = (
@@ -287,31 +290,13 @@ class DotsVisionAttention(nn.Module):
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        # Select attention backend
        self.attn_backend = get_vit_attn_backend(
            self.hidden_size_per_attention_head,
            torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
        )
        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Unsupported vision attention backend: {self.attn_backend}"
            )
        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def forward(
        self,
@@ -319,7 +304,7 @@ class DotsVisionAttention(nn.Module):
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        *,
        max_seqlen: int | None = None,
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # [S, C] -> [S, B=1, C]
        x = hidden_states.unsqueeze(1)
@@ -336,41 +321,13 @@ class DotsVisionAttention(nn.Module):
        qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
        q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
            k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
            v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
            output = self.flash_attn_varlen_func(
                q_,
                k_,
                v_,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )
            context_layer = output.view(
                bs,
                -1,
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            outputs = []
            for i in range(1, len(cu_seqlens)):
                s = int(cu_seqlens[i - 1])
                e = int(cu_seqlens[i])
                q_i = q[:, s:e].permute(0, 2, 1, 3)
                k_i = k[:, s:e].permute(0, 2, 1, 3)
                v_i = v[:, s:e].permute(0, 2, 1, 3)
                out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                out_i = out_i.permute(0, 2, 1, 3)
                outputs.append(out_i)
            context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
        else:
            raise RuntimeError("Unsupported attention backend")
        context_layer = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        # [B,S,H,D] -> [S,B,H*D] -> [S, C]
        context_layer = context_layer.permute(1, 0, 2, 3).contiguous()
@@ -385,14 +342,19 @@ class DotsSwiGLUFFN(nn.Module):
        config,
        *,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        hidden_features = config.intermediate_size
        in_features = config.embed_dim
        bias = config.use_bias

        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        # Referenced aimv2.py AIMv2SwiGLUFFN
        self.fc13 = MergedColumnParallelLinear(
            in_features,
@@ -498,9 +460,8 @@ class DotsVisionBlock(nn.Module):
        config,
        *,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()

@@ -510,16 +471,15 @@ class DotsVisionBlock(nn.Module):
            num_heads=config.num_attention_heads,
            bias=config.use_bias,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        self.mlp = DotsSwiGLUFFN(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )
        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

@@ -546,12 +506,11 @@ class DotsVisionTransformer(nn.Module):
        self,
        config: DotsVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        self.config = config
@@ -561,6 +520,11 @@ class DotsVisionTransformer(nn.Module):

        head_dim = config.embed_dim // config.num_attention_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend
            if multimodal_config is not None
            else None
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
@@ -578,9 +542,8 @@ class DotsVisionTransformer(nn.Module):
                DotsVisionBlock(
                    config,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.blocks.{i}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for i in range(num_layers)
            ]
@@ -592,6 +555,11 @@ class DotsVisionTransformer(nn.Module):
        else:
            self.post_trunk_norm = None

        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.merger = PatchMerger(
            dim=config.hidden_size,
            context_dim=config.embed_dim,
@@ -647,7 +615,7 @@ class DotsVisionTransformer(nn.Module):
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
@@ -733,17 +701,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
            self.config.vision_config = vision_config
        else:
            vision_config = self.config.vision_config
        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend
            if multimodal_config is not None
            else None
        )

        self.vision_tower = DotsVisionTransformer(
            vision_config,
            quant_config=self.quant_config,
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
            use_data_parallel=self.use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
            vllm_config=vllm_config,
@@ -37,10 +37,10 @@ from einops import rearrange, repeat
from transformers import BatchFeature

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
    maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
    MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@@ -163,8 +163,8 @@ class Ernie4_5_VisionAttention(nn.Module):
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
@@ -193,33 +193,13 @@ class Ernie4_5_VisionAttention(nn.Module):
            prefix=f"{prefix}.proj",
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
        )

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Ernie45-VL does not support {self.attn_backend} backend now."
            )
        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
@@ -253,14 +233,13 @@ class Ernie4_5_VisionAttention(nn.Module):
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
        if rotary_pos_emb is not None:
@@ -268,43 +247,14 @@ class Ernie4_5_VisionAttention(nn.Module):
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []

            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            q_chunks = torch.split(q, lens, dim=1)
            k_chunks = torch.split(k, lens, dim=1)
            v_chunks = torch.split(v, lens, dim=1)
            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        output = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        context_layer = rearrange(output, "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output
@@ -350,8 +300,8 @@ class Ernie4_5_VisionBlock(nn.Module):
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()

@@ -366,8 +316,8 @@ class Ernie4_5_VisionBlock(nn.Module):
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
            attn_backend_override=attn_backend_override,
        )

        self.mlp = Ernie4_5_VisionMLP(
@@ -383,7 +333,7 @@ class Ernie4_5_VisionBlock(nn.Module):
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
@@ -441,8 +391,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
        vision_config,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        patch_size = vision_config.patch_size
@@ -477,8 +427,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                attn_backend_override=attn_backend_override,
            )
            for layer_idx in range(depth)
        ]
@@ -489,6 +439,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
        )
        self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend if multimodal_config else None
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
@@ -535,13 +488,13 @@ class Ernie4_5_VisionTransformer(nn.Module):
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
        max_seqlen = None
        if (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
@@ -1304,17 +1257,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(
        self.config = config
        self.multimodal_config = multimodal_config

        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend
            if multimodal_config is not None
            else None
        )
        self.vision_model = Ernie4_5_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "vision_model"),
            attn_backend_override=attn_backend_override,
        )

        self.language_model = Ernie4_5_VLMoeForCausalLM(
@@ -47,8 +47,10 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.config import VllmConfig
from vllm.attention.layers.mm_encoder_attention import (
    MMEncoderAttention,
)
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
@@ -191,10 +193,15 @@ class Glm4vVisionMLP(nn.Module):
        hidden_features: int,
        bias: bool = False,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features] * 2,
@@ -248,12 +255,16 @@ class Glm4vVisionAttention(nn.Module):
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
@@ -287,34 +298,12 @@ class Glm4vVisionAttention(nn.Module):
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
            multimodal_config=multimodal_config,
        )

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"GLM-4V does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
@@ -338,14 +327,13 @@ class Glm4vVisionAttention(nn.Module):
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
        if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
@@ -356,43 +344,14 @@ class Glm4vVisionAttention(nn.Module):
            )
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []

            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            q_chunks = torch.split(q, lens, dim=1)
            k_chunks = torch.split(k, lens, dim=1)
            v_chunks = torch.split(v, lens, dim=1)
            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        context_layer = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output
@@ -406,9 +365,8 @@ class Glm4vVisionBlock(nn.Module):
        mlp_hidden_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
@@ -420,17 +378,16 @@ class Glm4vVisionBlock(nn.Module):
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Glm4vVisionMLP(
            dim,
            mlp_hidden_dim,
            bias=False,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
@@ -489,11 +446,16 @@ class Glm4vPatchMerger(nn.Module):
        d_model: int,
        context_dim: int,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.hidden_size = d_model
        self.proj = ColumnParallelLinear(
            self.hidden_size,
@@ -649,19 +611,19 @@ class Glm4vVisionTransformer(nn.Module):
        vision_config: Glm4vVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()

        assert multimodal_config is not None, "multimodal_config must be provided"

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.use_data_parallel = use_data_parallel

        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
@@ -690,9 +652,8 @@ class Glm4vVisionTransformer(nn.Module):
                mlp_hidden_dim=vision_config.out_hidden_size,
                norm_layer=norm_layer,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=self.use_data_parallel,
                attn_backend_override=attn_backend_override,
            )
            for layer_idx in range(depth)
        ]
@@ -701,9 +662,9 @@ class Glm4vVisionTransformer(nn.Module):
            d_model=vision_config.out_hidden_size,
            context_dim=vision_config.intermediate_size,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            bias=False,
            prefix=f"{prefix}.merger",
            use_data_parallel=self.use_data_parallel,
        )
        self.embeddings = Glm4vVisionEmbeddings(vision_config)

@@ -723,7 +684,7 @@ class Glm4vVisionTransformer(nn.Module):
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
            attn_backend_override=multimodal_config.mm_encoder_attn_backend,
        )

    @property
@@ -775,13 +736,13 @@ class Glm4vVisionTransformer(nn.Module):
    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> int | None:
    ) -> torch.Tensor | None:
        max_seqlen = None
        if (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
@@ -1465,18 +1426,12 @@ class Glm4vForConditionalGeneration(
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend
            if multimodal_config is not None
            else None
        )
        self.visual = Glm4vVisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel,
            attn_backend_override=attn_backend_override,
        )

        if config.model_type == "glm4v":
@@ -9,7 +9,6 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
@@ -17,11 +16,10 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@@ -80,7 +78,6 @@ from .utils import (
is_pp_missing_parameter,
maybe_prefix,
)
from .vision import get_vit_attn_backend

logger = init_logger(__name__)

@@ -369,8 +366,8 @@ class KeyeSiglipAttention(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -408,34 +405,14 @@ class KeyeSiglipAttention(nn.Module):
prefix=f"{prefix}.out_proj",
)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_heads,
head_size=self.head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
num_kv_heads=self.num_kv_heads,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now."
)

self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

def forward(
self,
hidden_states: torch.Tensor,
@@ -450,8 +427,7 @@ class KeyeSiglipAttention(nn.Module):
dim=-1,
)

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
batch_size = q.shape[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

if rope_emb is None:
q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
@@ -482,38 +458,14 @@ class KeyeSiglipAttention(nn.Module):
self.head_dim,
)

if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=False,
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]

context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> b s (h d)")

output, _ = self.out_proj(context_layer)
return output
@@ -547,8 +499,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@@ -556,8 +508,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self.self_attn = KeyeSiglipAttention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@@ -601,8 +553,8 @@ class KeyeSiglipEncoder(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -614,8 +566,8 @@ class KeyeSiglipEncoder(nn.Module):
KeyeSiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend_override=attn_backend_override,
)
for layer_idx in range(config.num_hidden_layers)
]
@@ -696,8 +648,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -707,8 +659,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self.encoder = KeyeSiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -779,16 +731,16 @@ class KeyeSiglipVisionModel(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()

self.vision_model = KeyeSiglipVisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
)
self.quant_config = quant_config

@@ -1329,16 +1281,11 @@ class BaseKeyeModule(nn.Module):
self.config = config
self.multimodal_config = multimodal_config

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
)

self.mlp_AR = self._build_projector(
@@ -240,18 +240,12 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
)

if multimodal_config.get_limit_per_prompt("image"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = OpenCUAVisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
else:
self.visual = None
@@ -10,8 +10,7 @@ import torch
import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -104,18 +103,16 @@ class VisualTokenizer(torch.nn.Module):
config: PretrainedConfig,
visual_vocab_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
self.vit = self._init_backbone(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vit",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
# reserved tokens for INDICATOR_IDS
head_dim = visual_vocab_size - len(INDICATOR_IDS)
@@ -133,18 +130,16 @@ class VisualTokenizer(torch.nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
model_type = config.model_type
if model_type == "siglip2_navit":
return Siglip2NavitModel(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=prefix,
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")

@@ -468,17 +463,12 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "llm"),
)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual_tokenizer = VisualTokenizer(
config=config.vit_config,
visual_vocab_size=config.visual_vocab_size,
multimodal_config=multimodal_config,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
attn_backend_override=attn_backend_override,
)

self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
@@ -22,7 +22,6 @@ from typing import Annotated, Literal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation
@@ -32,13 +31,10 @@ from transformers.modeling_outputs import (
from transformers.utils import torch_int

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@@ -578,9 +574,8 @@ class SiglipAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()

@@ -608,18 +603,12 @@ class SiglipAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)

self.attn_backend = attn_backend
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
seq_len, bs, _ = qkv.shape
@@ -665,44 +654,16 @@ class SiglipAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
if max_seqlen is None:
raise ValueError("Flash attention backend requires max_seqlen.")
context_layer = vit_flash_attn_wrapper(
q,
k,
v,
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(tensor, "b s h d -> b h s d")
for tensor in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
else:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
)
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> b s (h d)")

output, _ = self.out_proj(context_layer)
output = rearrange(output, "s b d -> b s d")
return output


@@ -774,10 +735,8 @@ class SiglipEncoderLayer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
*,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@@ -787,9 +746,8 @@ class SiglipEncoderLayer(nn.Module):
num_heads=config.num_attention_heads,
projection_size=config.hidden_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@@ -832,14 +790,18 @@ class SiglipEncoder(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
embed_dim = config.hidden_size
num_heads = config.num_attention_heads
head_dim = embed_dim // num_heads

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@@ -858,9 +820,8 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(config.num_hidden_layers)
]
@@ -941,8 +902,8 @@ class SiglipVisionTransformer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -952,8 +913,8 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -991,16 +952,16 @@ class SiglipVisionModel(nn.Module):
self,
config,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()

self.vision_model = SiglipVisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
)
self.quant_config = quant_config

@@ -1119,17 +1080,11 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.config = config
self.multimodal_config = multimodal_config

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)

self.visual = SiglipVisionModel(
config=config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
)
self.mlp_AR = Projector(config, config.vision_config)

@@ -845,6 +845,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
else:
self.visual = None
@@ -42,13 +42,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
)

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
@@ -267,10 +263,15 @@ class Qwen2_5_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
@@ -304,13 +305,16 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1
if use_data_parallel
@@ -342,18 +346,12 @@ class Qwen2_5_VisionAttention(nn.Module):
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel,
)
self.attn_backend = attn_backend
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)

self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
)

def forward(
self,
@@ -394,32 +392,17 @@ class Qwen2_5_VisionAttention(nn.Module):
else:
q, k, v = qkv.unbind(dim=2)

if self.is_flash_attn_backend:
context_layer = vit_flash_attn_wrapper(
q,
k,
v,
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

# Never remove the next contiguous logic
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
context_layer = vit_torch_sdpa_wrapper(
q,
k,
v,
cu_seqlens,
)
context_layer = einops.rearrange(
context_layer, "b s h d -> s b (h d)", b=batch_size
).contiguous()

output, _ = self.proj(context_layer)
return output
@@ -443,10 +426,8 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -458,10 +439,8 @@ class Qwen2_5_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2_5_VisionMLP(
dim,
@@ -469,8 +448,8 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)

def forward(
@@ -542,10 +521,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
@@ -586,9 +570,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
vision_config: Qwen2_5_VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()

@@ -598,7 +581,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel
self.out_hidden_size = vision_config.out_hidden_size

# args for get_window_index_thw
@@ -629,19 +611,17 @@ class Qwen2_5_VisionTransformer(nn.Module):
rope_parameters={"partial_rotary_factor": 0.5},
)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
@@ -661,10 +641,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@@ -677,8 +655,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)

@property
@@ -1200,18 +1178,12 @@ class Qwen2_5_VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
else:
self.visual = None

@@ -33,7 +33,6 @@ from typing import Annotated, Any, Literal, TypeAlias
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
@@ -45,10 +44,8 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils
@@ -251,10 +248,15 @@ class Qwen2VisionMLP(nn.Module):
hidden_features: int,
act_layer: type[nn.Module] = QuickGELU,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.fc1 = ColumnParallelLinear(
in_features,
hidden_features,
@@ -295,12 +297,16 @@ class Qwen2VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1
if use_data_parallel
@@ -329,34 +335,12 @@ class Qwen2VisionAttention(nn.Module):
disable_tp=use_data_parallel,
)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen2-VL does not support {self.attn_backend} backend now."
)

self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -398,7 +382,6 @@ class Qwen2VisionAttention(nn.Module):

# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]

q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))

@@ -409,49 +392,15 @@ class Qwen2VisionAttention(nn.Module):
)
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)

context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform

if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []

lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

output, _ = self.proj(context_layer)
return output
@@ -466,9 +415,8 @@ class Qwen2VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -482,17 +430,16 @@ class Qwen2VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2VisionMLP(
dim,
mlp_hidden_dim,
act_layer=act_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)

def forward(
@@ -552,10 +499,15 @@ class Qwen2VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
@@ -599,9 +551,8 @@ class Qwen2VisionTransformer(nn.Module):
vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()

@@ -615,7 +566,11 @@ class Qwen2VisionTransformer(nn.Module):
num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio

self.use_data_parallel = use_data_parallel
self.use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.out_hidden_size = vision_config.hidden_size

self.spatial_merge_size = spatial_merge_size
@@ -647,8 +602,7 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
for layer_idx in range(depth)
]
@@ -659,7 +613,10 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
multimodal_config=multimodal_config,
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
@@ -720,7 +677,7 @@ class Qwen2VisionTransformer(nn.Module):
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

def forward(
@@ -1324,18 +1281,12 @@ class Qwen2VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
else:
self.visual = None

@@ -48,7 +48,7 @@ from transformers.models.whisper import WhisperFeatureExtractor

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
@@ -192,6 +192,7 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
@@ -205,6 +206,7 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.mlp = Qwen3_VisionMLP(
@@ -299,8 +301,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
vision_config,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@@ -347,6 +349,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(vision_config.depth)
@@ -376,6 +379,12 @@ class Qwen3Omni_VisionTransformer(nn.Module):
]
)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)

self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@@ -1188,17 +1197,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(

self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
self.quant_config = quant_config


@@ -50,7 +50,7 @@ from transformers.video_utils import VideoMetadata

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
@@ -169,10 +169,15 @@ class Qwen3_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.linear_fc1 = ColumnParallelLinear(
in_features,
hidden_features,
@@ -206,10 +211,9 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
) -> None:
super().__init__()
if norm_layer is None:
@@ -221,9 +225,8 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
)
self.mlp = Qwen3_VisionMLP(
dim,
@@ -231,8 +234,8 @@ class Qwen3_VisionBlock(nn.Module):
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)

def forward(
@@ -264,10 +267,15 @@ class Qwen3_VisionPatchMerger(nn.Module):
spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)

self.use_postshuffle_norm = use_postshuffle_norm
@@ -313,9 +321,8 @@ class Qwen3_VisionTransformer(nn.Module):
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@@ -326,7 +333,6 @@ class Qwen3_VisionTransformer(nn.Module):
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.use_data_parallel = use_data_parallel
self.num_grid_per_side = int(self.num_position_embeddings**0.5)

# NOTE: This is used for creating empty tensor for all_gather for
@@ -359,8 +365,8 @@ class Qwen3_VisionTransformer(nn.Module):
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)

self.deepstack_merger_list = nn.ModuleList(
@@ -372,13 +378,16 @@ class Qwen3_VisionTransformer(nn.Module):
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
use_data_parallel=use_data_parallel,
)
for layer_idx in range(len(self.deepstack_visual_indexes))
]
)

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@@ -402,9 +411,8 @@ class Qwen3_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
)
for layer_idx in range(vision_config.depth)
]
@@ -1277,18 +1285,12 @@ class Qwen3VLForConditionalGeneration(
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)

self.language_model = Qwen3LLMForCausalLM(

@@ -418,7 +418,6 @@ class Qwen3VLMoeForConditionalGeneration(

self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

if not multimodal_config.get_limit_per_prompt(
"image"
@@ -429,8 +428,8 @@ class Qwen3VLMoeForConditionalGeneration(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)

self.language_model = Qwen3MoeLLMForCausalLM(

@ -13,7 +13,8 @@ from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@ -28,8 +29,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform

from .vision import get_vit_attn_backend


class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
@ -190,7 +189,7 @@ def apply_rotary_pos_emb(
) -> tuple[torch.Tensor, torch.Tensor]:
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()
    if is_flash_attn_backend and not current_platform.is_xpu():
    if is_flash_attn_backend and current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        apply_rotary_emb_func = apply_rotary_emb
@ -208,6 +207,7 @@ class Siglip2Attention(nn.Module):
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
@ -227,20 +227,25 @@ class Siglip2Attention(nn.Module):
        self.dropout = config.attention_dropout
        self.is_causal = False

        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        self.tp_size = (
@ -249,31 +254,13 @@ class Siglip2Attention(nn.Module):
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
        self.use_rope = config.use_rope

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
        self.attn = MMEncoderAttention(
            num_heads=self.num_heads_per_partition,
            head_size=self.head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
            prefix=f"{prefix}.attn",
            multimodal_config=multimodal_config,
        )

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            self.attn_backend = AttentionBackendEnum.TORCH_SDPA
        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
@ -298,46 +285,23 @@ class Siglip2Attention(nn.Module):
            keys.unsqueeze(0),
            cos,
            sin,
            self.is_flash_attn_backend,
            self.attn.is_flash_attn_backend,
        )
        queries = queries.squeeze(0)
        keys = keys.squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        if self.is_flash_attn_backend:
            attn_output = self.flash_attn_varlen_func(
                queries,
                keys,
                values,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
            ).reshape(seq_length, -1)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            batch_size = cu_seqlens.shape[0] - 1
            outputs = []
            cu = cu_seqlens.tolist()
            for i in range(batch_size):
                start_idx = cu[i]
                end_idx = cu[i + 1]
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        attn_output = self.attn(
            query=queries.unsqueeze(0),
            key=keys.unsqueeze(0),
            value=values.unsqueeze(0),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        attn_output = attn_output.reshape(
            seq_length, self.num_heads_per_partition * self.head_dim
        )

                # Each sequence is processed independently.
                q_i = queries[start_idx:end_idx].unsqueeze(0)
                k_i = keys[start_idx:end_idx].unsqueeze(0)
                v_i = values[start_idx:end_idx].unsqueeze(0)

                # (1, seq_len, num_heads, head_dim) ->
                # (1, num_heads, seq_len, head_dim)
                q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]

                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
                output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
                outputs.append(output_i)

            attn_output = torch.cat(outputs, dim=0)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output
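The hunks above replace Siglip2Attention's hand-rolled backend dispatch (varlen flash attention vs. a per-sequence SDPA loop) with a single MMEncoderAttention module that is constructed once and then called with packed sequences. As a rough, self-contained illustration of that calling convention, the sketch below implements the same query/key/value plus cu_seqlens/max_seqlen interface with a naive per-sequence SDPA loop; it is a simplified stand-in, not vLLM's actual MMEncoderAttention implementation.

import torch
import torch.nn.functional as F


class NaiveVarlenEncoderAttention(torch.nn.Module):
    """Simplified stand-in for an MMEncoderAttention-style custom op.

    Accepts packed (varlen) sequences described by cu_seqlens and runs torch
    SDPA per sequence, mirroring the TORCH_SDPA fallback path in the old code.
    """

    def __init__(self, num_heads: int, head_size: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        # The real op would select FLASH_ATTN / ROCM_AITER_FA / TORCH_SDPA here.
        self.is_flash_attn_backend = False

    def forward(
        self,
        query: torch.Tensor,  # (1, total_tokens, num_heads, head_dim)
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor,  # (num_seqs + 1,), cumulative sequence lengths
        max_seqlen: torch.Tensor | int,  # only flash-attn kernels need this
    ) -> torch.Tensor:
        q, k, v = query.squeeze(0), key.squeeze(0), value.squeeze(0)
        outputs = []
        bounds = cu_seqlens.tolist()
        for start, end in zip(bounds[:-1], bounds[1:]):
            # (seq_len, num_heads, head_dim) -> (1, num_heads, seq_len, head_dim)
            q_i, k_i, v_i = (
                x[start:end].transpose(0, 1).unsqueeze(0) for x in (q, k, v)
            )
            o_i = F.scaled_dot_product_attention(q_i, k_i, v_i)
            # Back to (seq_len, num_heads, head_dim).
            outputs.append(o_i.squeeze(0).transpose(0, 1))
        return torch.cat(outputs, dim=0)  # (total_tokens, num_heads, head_dim)


# Usage mirroring the rewritten Siglip2Attention.forward above:
num_heads, head_dim = 4, 32
cu_seqlens = torch.tensor([0, 5, 12])  # two packed sequences of length 5 and 7
qkv = torch.randn(3, 1, 12, num_heads, head_dim)
attn = NaiveVarlenEncoderAttention(num_heads, head_dim)
out = attn(
    qkv[0], qkv[1], qkv[2],
    cu_seqlens=cu_seqlens,
    max_seqlen=(cu_seqlens[1:] - cu_seqlens[:-1]).max(),
)
print(out.shape)  # torch.Size([12, 4, 32])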
@ -347,25 +311,30 @@ class Siglip2MLP(nn.Module):
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.activation_fn = get_act_fn(config.hidden_act)
        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -380,9 +349,8 @@ class Siglip2EncoderLayer(nn.Module):
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
@ -390,16 +358,15 @@ class Siglip2EncoderLayer(nn.Module):
        self.self_attn = Siglip2Attention(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.self_attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
@ -444,9 +411,8 @@ class Siglip2Encoder(nn.Module):
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.config = config
@ -455,9 +421,8 @@ class Siglip2Encoder(nn.Module):
                Siglip2EncoderLayer(
                    config,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.layers.{idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for idx in range(config.num_hidden_layers)
            ]
@ -630,9 +595,8 @@ class Siglip2VisionTransformer(nn.Module):
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.config = config
@ -642,9 +606,8 @@ class Siglip2VisionTransformer(nn.Module):
        self.encoder = Siglip2Encoder(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.encoder",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@ -671,18 +634,16 @@ class Siglip2NavitModel(torch.nn.Module):
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()

        self.vision_model = Siglip2VisionTransformer(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.vision_model",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )

    def forward(
@ -88,14 +88,17 @@ def get_vit_attn_backend(
    """
    Get the available attention backend for Vision Transformer.
    """
    if attn_backend_override is not None:
        return attn_backend_override
    attn_backend = attn_backend_override

    selected_backend = get_current_vllm_config().attention_config.backend
    if selected_backend is not None:
        return selected_backend
    if attn_backend is None:
        attn_backend = selected_backend

    return current_platform.get_vit_attn_backend(head_size, dtype)
    return current_platform.get_vit_attn_backend(
        head_size,
        dtype,
        backend=attn_backend,
    )


def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
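The change above reworks get_vit_attn_backend so that an explicit per-model override no longer short-circuits the platform hook: the override (or, failing that, the globally configured attention backend) is forwarded to current_platform.get_vit_attn_backend, which validates it or applies its own default. A minimal sketch of that precedence, using an illustrative enum in place of AttentionBackendEnum and plain arguments in place of vLLM's config objects:

from enum import Enum, auto


class Backend(Enum):
    # Illustrative stand-in for AttentionBackendEnum.
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def resolve_vit_backend(
    override: Backend | None,
    global_backend: Backend | None,
    platform_default: Backend,
) -> Backend:
    """Mirror the precedence in get_vit_attn_backend:
    model-level override > global attention config > platform default."""
    backend = override
    if backend is None:
        backend = global_backend
    # The real code forwards `backend` to current_platform.get_vit_attn_backend,
    # which checks it against the platform's supported list or falls back.
    return backend if backend is not None else platform_default


assert resolve_vit_backend(None, None, Backend.TORCH_SDPA) is Backend.TORCH_SDPA
assert resolve_vit_backend(
    Backend.FLASH_ATTN, Backend.TORCH_SDPA, Backend.TORCH_SDPA
) is Backend.FLASH_ATTN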
@ -7,7 +7,7 @@ pynvml. However, it should not initialize cuda context.
import os
from collections.abc import Callable
from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar

import torch
from typing_extensions import ParamSpec
@ -255,23 +255,6 @@ class CudaPlatformBase(Platform):
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> "AttentionBackendEnum":
        # Try FlashAttention first
        if (cc := cls.get_device_capability()) and cc.major >= 8:
            try:
                backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
                if backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype):
                    return AttentionBackendEnum.FLASH_ATTN
            except ImportError:
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_valid_backends(
        cls,
@ -418,6 +401,41 @@ class CudaPlatformBase(Platform):

        return selected_backend.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        return [
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.FLASH_ATTN,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        # Try FlashAttention first
        if (cc := cls.get_device_capability()) and cc.major >= 8:
            try:
                backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
                if backend_class.supports_head_size(
                    head_size
                ) and backend_class.supports_dtype(dtype):
                    return AttentionBackendEnum.FLASH_ATTN
            except ImportError:
                pass

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
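On CUDA, the default selection above probes whether FlashAttention is usable (compute capability 8.0 or newer, an importable backend class, and a supported head size and dtype) and otherwise falls back to torch SDPA. The sketch below restates that order with plain strings and stand-in support checks; the fp16/bf16 and 256 head-size limits are assumptions about FlashAttention, not values taken from this diff.

import torch


def pick_cuda_vit_backend(head_size: int, dtype: torch.dtype) -> str:
    """Illustrative restatement of the CUDA default selection: prefer
    FlashAttention on Ampere+ GPUs when it can handle the head size and dtype,
    otherwise fall back to torch SDPA. Strings replace AttentionBackendEnum."""
    if not torch.cuda.is_available():
        return "TORCH_SDPA"
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        try:
            import flash_attn  # noqa: F401  # ImportError -> fall back below
            # Assumed FlashAttention constraints: fp16/bf16, head dim <= 256.
            if dtype in (torch.float16, torch.bfloat16) and head_size <= 256:
                return "FLASH_ATTN"
        except ImportError:
            pass
    return "TORCH_SDPA"


print(pick_cuda_vit_backend(head_size=72, dtype=torch.bfloat16))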
@ -7,7 +7,7 @@ import platform
import random
import sys
from datetime import timedelta
from typing import TYPE_CHECKING, Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple, Optional

import numpy as np
import torch
@ -222,12 +222,6 @@ class Platform:
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> "AttentionBackendEnum":
        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(
        cls,
@ -245,6 +239,43 @@ class Platform:
        """Get the attention backend class of a device."""
        return ""

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        return [
            AttentionBackendEnum.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """
        Get the vision attention backend class of a device.

        NOTE: ViT Attention should be checked and override in the platform-specific
        implementation. we should not override this in any other places, like
        the model_executor/models/<model_name>.py.

        We check if the backend is None or not:
        1. If not, check if the backend is supported by the platform.
        2. If None, continue to the default selection logic.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention"
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        logger.info_once(
            f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention"
        )
        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_device_capability(
        cls,
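The base Platform hooks above establish a small contract: get_supported_vit_attn_backends advertises what a platform can run, and get_vit_attn_backend validates an explicitly requested backend against that list before falling back to the platform default. A compressed sketch of that contract follows; the classes are hypothetical, and the default is collapsed into a class attribute for brevity, whereas the real platforms each implement their own default-selection logic.

from enum import Enum, auto


class Backend(Enum):
    # Illustrative stand-in for AttentionBackendEnum.
    TORCH_SDPA = auto()
    FLASH_ATTN = auto()
    PALLAS = auto()


class BasePlatform:
    """Sketch of the hook contract: list supported ViT backends, validate an
    explicit request against that list, otherwise use the platform default."""

    _default = Backend.TORCH_SDPA

    @classmethod
    def supported_vit_backends(cls) -> list[Backend]:
        return [Backend.TORCH_SDPA]

    @classmethod
    def get_vit_attn_backend(cls, backend: Backend | None = None) -> Backend:
        if backend is not None:
            assert backend in cls.supported_vit_backends(), (
                f"{backend} is not supported for ViT attention; "
                f"supported: {cls.supported_vit_backends()}"
            )
            return backend
        return cls._default


class TpuLikePlatform(BasePlatform):
    # A TPU-like platform would advertise and default to its own kernel.
    _default = Backend.PALLAS

    @classmethod
    def supported_vit_backends(cls) -> list[Backend]:
        return [Backend.PALLAS]


print(TpuLikePlatform.get_vit_attn_backend())                 # Backend.PALLAS
print(BasePlatform.get_vit_attn_backend(Backend.TORCH_SDPA))  # Backend.TORCH_SDPA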
@ -3,7 +3,7 @@

import os
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import torch

@ -187,24 +187,6 @@ class RocmPlatform(Platform):
        if not on_gfx9():
            supported_quantization += ["bitsandbytes"]

    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> AttentionBackendEnum:
        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_mha_enabled():
            # Note: AITER FA is only supported for Qwen-VL models.
            # TODO: Add support for other VL models in their model class.
            return AttentionBackendEnum.ROCM_AITER_FA

        if on_gfx9() and find_spec("flash_attn") is not None:
            return AttentionBackendEnum.FLASH_ATTN

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(
        cls,
@ -322,6 +304,43 @@ class RocmPlatform(Platform):
            "ROCm. Note that V0 attention backends have been removed."
        )

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        from importlib.util import find_spec

        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_mha_enabled():
            # Note: AITER FA is only supported for Qwen-VL models.
            # TODO: Add support for other VL models in their model class.
            return AttentionBackendEnum.ROCM_AITER_FA

        if on_gfx9() and find_spec("flash_attn") is not None:
            return AttentionBackendEnum.FLASH_ATTN

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
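ROCm keeps its own default order above: the AITER flash-attention path when AITER MHA is enabled, upstream flash_attn on gfx9 when it is installed (detected via importlib.util.find_spec), and torch SDPA otherwise. A small stand-in restating that order, with booleans in place of rocm_aiter_ops.is_mha_enabled() and on_gfx9():

from importlib.util import find_spec


def pick_rocm_vit_backend(aiter_mha_enabled: bool, is_gfx9: bool) -> str:
    """Illustrative restatement of the ROCm default selection: AITER FA when
    its MHA mode is enabled, then flash_attn if installed on a gfx9 GPU,
    otherwise torch SDPA. Strings replace AttentionBackendEnum."""
    if aiter_mha_enabled:
        return "ROCM_AITER_FA"
    if is_gfx9 and find_spec("flash_attn") is not None:
        return "FLASH_ATTN"
    return "TORCH_SDPA"


print(pick_rocm_vit_backend(aiter_mha_enabled=False, is_gfx9=False))  # TORCH_SDPA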
@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Optional, cast

import torch
from tpu_info import device
@ -75,6 +75,32 @@ class TpuPlatform(Platform):
        logger.info("Using Pallas V1 backend.")
        return AttentionBackendEnum.PALLAS.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        return [
            AttentionBackendEnum.PALLAS,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention"
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
            )
            logger.info_once(f"Using backend {backend} for vit attention.")
            return backend

        logger.info_once(
            f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
        )
        return AttentionBackendEnum.PALLAS

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
@ -3,7 +3,7 @@

import contextlib
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import torch

@ -77,6 +77,34 @@ class XPUPlatform(Platform):
        logger.info("Using Flash Attention backend.")
        return AttentionBackendEnum.FLASH_ATTN.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        # XPU only supports FLASH_ATTN for vision attention.
        return [
            AttentionBackendEnum.FLASH_ATTN,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: "
                f"{cls.get_supported_vit_attn_backends()}."
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        logger.info_once(
            f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
        )
        return AttentionBackendEnum.FLASH_ATTN

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
@ -110,12 +138,6 @@ class XPUPlatform(Platform):
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> "AttentionBackendEnum":
        return AttentionBackendEnum.FLASH_ATTN

    @classmethod
    def inference_mode(cls):
        return torch.no_grad()