mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 13:42:19 +08:00
[MM Encoder]: Migrate legacy ViT MultiHeadAttention to new MMEncoderAttention interface (#30684)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
62be3670cb
commit
700a5ad6c6
@ -9,7 +9,8 @@ import torch
|
|||||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||||
from tests.kernels.utils import opcheck
|
from tests.kernels.utils import opcheck
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.attention.layer import Attention, MultiHeadAttention
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.mem_utils import get_max_shared_memory_bytes
|
from vllm.utils.mem_utils import get_max_shared_memory_bytes
|
||||||
|
|
||||||
@ -442,7 +443,7 @@ def ref_multi_query_kv_attention(
|
|||||||
return torch.cat(ref_outputs, dim=0)
|
return torch.cat(ref_outputs, dim=0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
|
@pytest.mark.parametrize("attention_cls", [Attention, MMEncoderAttention])
|
||||||
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
|
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
|
||||||
head_size = 64
|
head_size = 64
|
||||||
scale = float(1.0 / (head_size**0.5))
|
scale = float(1.0 / (head_size**0.5))
|
||||||
|
|||||||
@ -3,16 +3,17 @@
|
|||||||
"""
|
"""
|
||||||
Test:
|
Test:
|
||||||
|
|
||||||
* Tests for MultiHeadAttention layer
|
* Tests for MMEncoderAttention layer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import itertools
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.attention.selector import _cached_get_attn_backend
|
from vllm.attention.selector import _cached_get_attn_backend
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.cpu import CpuPlatform
|
from vllm.platforms.cpu import CpuPlatform
|
||||||
@ -42,35 +43,31 @@ def test_mha_attn_platform(device: str):
|
|||||||
|
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
with (
|
with (
|
||||||
patch("vllm.attention.layer.current_platform", CpuPlatform()),
|
|
||||||
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
|
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
|
||||||
):
|
):
|
||||||
attn = MultiHeadAttention(16, 64, scale=1)
|
attn = MMEncoderAttention(16, 64, scale=1)
|
||||||
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
|
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
|
||||||
elif device == "hip":
|
elif device == "hip":
|
||||||
with (
|
with (
|
||||||
patch("vllm.attention.layer.current_platform", RocmPlatform()),
|
|
||||||
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
|
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
|
||||||
):
|
):
|
||||||
attn = MultiHeadAttention(16, 64, scale=1)
|
attn = MMEncoderAttention(16, 64, scale=1)
|
||||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||||
else:
|
else:
|
||||||
# Test CUDA with head_size=64 (divisible by 32)
|
# Test CUDA with head_size=64 (divisible by 32)
|
||||||
# - should use vLLM's FlashAttention
|
# - should use vLLM's FlashAttention
|
||||||
with (
|
with (
|
||||||
patch("vllm.attention.layer.current_platform", CudaPlatform()),
|
|
||||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||||
):
|
):
|
||||||
attn = MultiHeadAttention(16, 64, scale=1)
|
attn = MMEncoderAttention(16, 64, scale=1)
|
||||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||||
|
|
||||||
# Test CUDA with head_size=72 (not divisible by 32)
|
# Test CUDA with head_size=72 (not divisible by 32)
|
||||||
# - should use vLLM's FlashAttention
|
# - should use vLLM's FlashAttention
|
||||||
with (
|
with (
|
||||||
patch("vllm.attention.layer.current_platform", CudaPlatform()),
|
|
||||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||||
):
|
):
|
||||||
attn = MultiHeadAttention(16, 72, scale=1)
|
attn = MMEncoderAttention(16, 72, scale=1)
|
||||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||||
|
|
||||||
|
|
||||||
@ -94,6 +91,10 @@ def ref_attention(
|
|||||||
|
|
||||||
BATCH_SIZES = [1, 16]
|
BATCH_SIZES = [1, 16]
|
||||||
SEQ_LENS = [1]
|
SEQ_LENS = [1]
|
||||||
|
VAR_SEQ_LENS = [
|
||||||
|
[2, 2],
|
||||||
|
[2, 3, 4],
|
||||||
|
]
|
||||||
NUM_HEADS = [1, 16]
|
NUM_HEADS = [1, 16]
|
||||||
NUM_KV_HEADS = [1]
|
NUM_KV_HEADS = [1]
|
||||||
HEAD_SIZES = [64, 80]
|
HEAD_SIZES = [64, 80]
|
||||||
@ -130,7 +131,7 @@ def test_mha_attn_forward(
|
|||||||
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
||||||
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
||||||
scale = 1.0 / head_size**0.5
|
scale = 1.0 / head_size**0.5
|
||||||
attn = MultiHeadAttention(
|
attn = MMEncoderAttention(
|
||||||
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
|
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
|
||||||
)
|
)
|
||||||
output = attn(q, k, v)
|
output = attn(q, k, v)
|
||||||
@ -151,3 +152,58 @@ def test_mha_attn_forward(
|
|||||||
scale=scale,
|
scale=scale,
|
||||||
).reshape(batch_size, seq_len, num_heads * head_size)
|
).reshape(batch_size, seq_len, num_heads * head_size)
|
||||||
torch.testing.assert_close(output, ref_output)
|
torch.testing.assert_close(output, ref_output)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward(
    var_seq_len: list[int],
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """Check MMEncoderAttention on packed variable-length input.

    Several sequences (lengths given by ``var_seq_len``) are concatenated
    along dim 1 of a single batch row and passed to the layer together with
    ``cu_seqlens``/``max_seqlen``. The result is compared against
    ``ref_attention`` applied to each sequence separately.
    """
    current_platform.seed_everything(0)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    total_tokens = sum(var_seq_len)
    longest = max(var_seq_len)
    softmax_scale = 1.0 / head_size**0.5

    # q, k, v are drawn in this order so the seeded RNG stream is stable.
    q = torch.randn(1, total_tokens, num_heads, head_size)
    k = torch.randn(1, total_tokens, num_kv_heads, head_size)
    v = torch.randn(1, total_tokens, num_kv_heads, head_size)

    # Cumulative sequence boundaries: [0, len0, len0 + len1, ...].
    boundaries = list(itertools.accumulate(var_seq_len, initial=0))
    cu_seqlens = torch.tensor(boundaries, dtype=torch.int32)

    layer = MMEncoderAttention(
        num_heads, head_size, scale=softmax_scale, num_kv_heads=num_kv_heads
    )
    output = layer(
        q, k, v, cu_seqlens=cu_seqlens, max_seqlen=torch.tensor(longest)
    )

    # The reference works on matching head counts, so expand K/V for MQA/GQA.
    assert num_heads % num_kv_heads == 0
    group_size = num_heads // num_kv_heads
    if group_size > 1:
        k, v = (torch.repeat_interleave(t, group_size, dim=2) for t in (k, v))

    # Run the reference one sequence at a time, then re-pack along dim 1.
    per_sequence = zip(*(torch.split(t, var_seq_len, dim=1) for t in (q, k, v)))
    expected = torch.cat(
        [
            ref_attention(q_i, k_i, v_i, scale=softmax_scale)
            for q_i, k_i, v_i in per_sequence
        ],
        dim=1,
    )
    torch.testing.assert_close(output, expected, atol=1e-2, rtol=1e-2)
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
"""
|
"""
|
||||||
Test:
|
Test:
|
||||||
|
|
||||||
* Tests for MultiHeadAttention layer
|
* Tests for MMEncoderAttention layer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -12,7 +12,7 @@ import torch_xla
|
|||||||
import torch_xla.core
|
import torch_xla.core
|
||||||
import torch_xla.core.xla_model
|
import torch_xla.core.xla_model
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.attention.selector import _cached_get_attn_backend
|
from vllm.attention.selector import _cached_get_attn_backend
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ def test_mha_attn_forward(
|
|||||||
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
|
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
|
||||||
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
|
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
|
||||||
scale = 1.0 / head_size**0.5
|
scale = 1.0 / head_size**0.5
|
||||||
attn = MultiHeadAttention(
|
attn = MMEncoderAttention(
|
||||||
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
|
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
|
||||||
)
|
)
|
||||||
output = attn(q, k, v)
|
output = attn(q, k, v)
|
||||||
|
|||||||
@ -2,12 +2,10 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Attention layer."""
|
"""Attention layer."""
|
||||||
|
|
||||||
import functools
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.abstract import (
|
from vllm.attention.backends.abstract import (
|
||||||
@ -16,13 +14,10 @@ from vllm.attention.backends.abstract import (
|
|||||||
MLAAttentionImpl,
|
MLAAttentionImpl,
|
||||||
)
|
)
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
|
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
|
||||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||||
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
|
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
|
||||||
from vllm.config import CacheConfig, get_current_vllm_config
|
from vllm.config import CacheConfig, get_current_vllm_config
|
||||||
from vllm.config.multimodal import MultiModalConfig
|
|
||||||
from vllm.config.vllm import VllmConfig
|
from vllm.config.vllm import VllmConfig
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -36,7 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
|||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import (
|
from vllm.utils.torch_utils import (
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
@ -412,132 +406,6 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
    """Cache-free multi-head attention layer, used for ViT encoders."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        # Accepted only so this class can be swapped with Attention; it is
        # stored as ``layer_name`` and otherwise unused here.
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """Record head geometry and pick the ViT attention backend.

        ``num_kv_heads`` defaults to ``num_heads``. When ``multimodal_config``
        is given, its ``mm_encoder_attn_backend`` is forwarded as the backend
        override.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # Backend resolution. The default dtype at model-init time is the
        # model's weight/activation dtype, so it is what the selector sees.
        backend_override = (
            multimodal_config.mm_encoder_attn_backend
            if multimodal_config is not None
            else None
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=torch.get_default_dtype(),
            attn_backend_override=backend_override,
        )
        self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
            self.attn_backend,
        )
        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        # Only FLASH_ATTN on CUDA gets a pinned flash-attention version.
        self.fa_version = None
        uses_cuda_flash_attn = (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            and current_platform.is_cuda()
        )
        if uses_cuda_flash_attn:
            self.fa_version = get_flash_attn_version()
            assert self._flash_attn_varlen_func is not None
            self._flash_attn_varlen_func = functools.partial(
                self._flash_attn_varlen_func, fa_version=self.fa_version
            )

        logger.info_once(
            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and return ``(batch_size, q_len, -1)``.

        Input shape:
            (batch_size x seq_len x hidden_size) or
            (batch_size x seq_len x num_heads x head_size)

        Raises:
            NotImplementedError: if the selected backend is not one of the
                flash-attention backends, TORCH_SDPA or PALLAS.
        """
        batch, q_len = query.size()[:2]
        kv_len = key.size(1)

        # Normalise everything to (batch, seq, heads, head_size).
        query = query.view(batch, q_len, self.num_heads, self.head_size)
        kv_shape = (batch, kv_len, self.num_kv_heads, self.head_size)
        key = key.view(kv_shape)
        value = value.view(kv_shape)

        repeats = self.num_queries_per_kv
        if repeats > 1:
            # MQA / GQA: replicate each KV head for its group of query heads.
            key, value = (
                torch.repeat_interleave(t, repeats, dim=2) for t in (key, value)
            )

        backend = self.attn_backend
        if self.is_flash_attn_backend:
            varlen_attn = self._flash_attn_varlen_func
            assert varlen_attn is not None
            # Every batch row has the same length, so the cumulative offsets
            # are just multiples of that length.
            q_offsets = torch.arange(
                0, (batch + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
            )
            kv_offsets = torch.arange(
                0, (batch + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
            )
            out = varlen_attn(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=q_offsets,
                cu_seqlens_k=kv_offsets,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif backend == AttentionBackendEnum.TORCH_SDPA:
            # SDPA wants (batch, heads, seq, head_size); swap back afterwards.
            out = F.scaled_dot_product_attention(
                query.transpose(1, 2),
                key.transpose(1, 2),
                value.transpose(1, 2),
                scale=self.scale,
            ).transpose(1, 2)
        elif backend == AttentionBackendEnum.PALLAS:
            # Imported lazily so torch_xla is only needed on this backend.
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(
                query.transpose(1, 2),
                key.transpose(1, 2),
                value.transpose(1, 2),
                sm_scale=self.scale,
            ).transpose(1, 2)
        else:
            # No ViT attention implementation exists for this backend.
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} backend yet."
            )

        return out.reshape(batch, q_len, -1)
|
|
||||||
|
|
||||||
|
|
||||||
class MLAAttention(nn.Module, AttentionLayerBase):
|
class MLAAttention(nn.Module, AttentionLayerBase):
|
||||||
"""Multi-Head Latent Attention layer.
|
"""Multi-Head Latent Attention layer.
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -19,27 +18,6 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def maybe_get_vit_flash_attn_backend(
    attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
    """Return the ``flash_attn_varlen_func`` matching ``attn_backend``.

    The backend is taken as given: no overriding happens here (the caller is
    expected to have resolved it already). The kernel modules are imported
    lazily, so only the selected backend's package has to be importable.

    Returns:
        The varlen flash-attention callable for FLASH_ATTN or ROCM_AITER_FA,
        or ``None`` for any other backend (e.g. TORCH_SDPA).
    """
    if attn_backend == AttentionBackendEnum.FLASH_ATTN:
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func

        return flash_attn_varlen_func

    if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func

        return flash_attn_varlen_func

    # Any other backend has no flash-attention entry point.
    return None
|
|
||||||
|
|
||||||
|
|
||||||
@CustomOp.register("mm_encoder_attn")
|
@CustomOp.register("mm_encoder_attn")
|
||||||
class MMEncoderAttention(CustomOp):
|
class MMEncoderAttention(CustomOp):
|
||||||
"""Multi-headed attention without any cache, used for multimodal encoder."""
|
"""Multi-headed attention without any cache, used for multimodal encoder."""
|
||||||
@ -98,21 +76,17 @@ class MMEncoderAttention(CustomOp):
|
|||||||
AttentionBackendEnum.ROCM_AITER_FA,
|
AttentionBackendEnum.ROCM_AITER_FA,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
|
self._fa_version = (
|
||||||
self.attn_backend,
|
get_flash_attn_version() if self.is_flash_attn_backend else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.is_flash_attn_backend:
|
|
||||||
assert self.flash_attn_varlen_func is not None
|
|
||||||
self._fa_version = get_flash_attn_version()
|
|
||||||
|
|
||||||
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
|
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enabled(cls) -> bool:
|
def enabled(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def reshape_qkv_to_4d(
|
def maybe_reshape_qkv_to_4d(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
@ -136,30 +110,6 @@ class MMEncoderAttention(CustomOp):
|
|||||||
|
|
||||||
return query, key, value
|
return query, key, value
|
||||||
|
|
||||||
def reshape_qkv_to_3d(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Flatten batch and sequence dims of query, key and value.

    Each tensor is viewed as (batch_size * seq_len, num_heads, head_size).
    When there are several query heads per KV head (MQA/GQA), key and value
    heads are repeated along dim 1 so they line up with the query heads.
    """
    head_dim = self.head_size
    query = query.view(bsz * q_len, self.num_heads, head_dim)
    key, value = (
        t.view(bsz * kv_len, self.num_kv_heads, head_dim) for t in (key, value)
    )

    group_size = self.num_queries_per_kv
    if group_size > 1:
        # Expand each KV head to cover its group of query heads.
        key, value = (
            torch.repeat_interleave(t, group_size, dim=1) for t in (key, value)
        )

    return query, key, value
|
|
||||||
|
|
||||||
def _forward_sdpa(
|
def _forward_sdpa(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@ -167,13 +117,15 @@ class MMEncoderAttention(CustomOp):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# TODO(Isotr0py): Migrate MultiHeadAttention
|
"""Input shape:
|
||||||
assert cu_seqlens is not None
|
(batch_size x seq_len x hidden_size) or
|
||||||
|
(batch_size x seq_len x num_heads x head_size)
|
||||||
|
"""
|
||||||
bsz, q_len = query.size()[:2]
|
bsz, q_len = query.size()[:2]
|
||||||
kv_len = key.size(1)
|
kv_len = key.size(1)
|
||||||
|
is_reshaped = query.dim() != 4
|
||||||
|
|
||||||
query, key, value = self.reshape_qkv_to_4d(
|
query, key, value = self.maybe_reshape_qkv_to_4d(
|
||||||
query, key, value, bsz, q_len, kv_len
|
query, key, value, bsz, q_len, kv_len
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -183,6 +135,8 @@ class MMEncoderAttention(CustomOp):
|
|||||||
v=value,
|
v=value,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
)
|
)
|
||||||
|
if is_reshaped:
|
||||||
|
output = output.view(bsz, q_len, -1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _forward_fa(
|
def _forward_fa(
|
||||||
@ -193,13 +147,21 @@ class MMEncoderAttention(CustomOp):
|
|||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert self.flash_attn_varlen_func is not None, (
|
"""Input shape:
|
||||||
"Flash attention function is not set."
|
(batch_size x seq_len x hidden_size) or
|
||||||
)
|
(batch_size x seq_len x num_heads x head_size)
|
||||||
# # TODO(Isotr0py): Migrate MultiHeadAttention
|
"""
|
||||||
assert cu_seqlens is not None and max_seqlen is not None
|
assert (cu_seqlens is not None and max_seqlen is not None) or (
|
||||||
|
cu_seqlens is None and max_seqlen is None
|
||||||
|
), "cu_seqlens and max_seqlen should be both set or both None."
|
||||||
|
|
||||||
bsz = query.shape[0]
|
bsz, q_len = query.size()[:2]
|
||||||
|
kv_len = key.size(1)
|
||||||
|
is_reshaped = query.dim() != 4
|
||||||
|
|
||||||
|
query, key, value = self.maybe_reshape_qkv_to_4d(
|
||||||
|
query, key, value, bsz, q_len, kv_len
|
||||||
|
)
|
||||||
|
|
||||||
output = vit_flash_attn_wrapper(
|
output = vit_flash_attn_wrapper(
|
||||||
q=query,
|
q=query,
|
||||||
@ -211,6 +173,8 @@ class MMEncoderAttention(CustomOp):
|
|||||||
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
|
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
|
||||||
fa_version=self._fa_version,
|
fa_version=self._fa_version,
|
||||||
)
|
)
|
||||||
|
if is_reshaped:
|
||||||
|
output = output.view(bsz, q_len, -1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
|
|||||||
@ -24,11 +24,11 @@ def flash_attn_maxseqlen_wrapper(
|
|||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
|
||||||
max_seqlen: torch.Tensor,
|
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
is_rocm_aiter: bool,
|
is_rocm_aiter: bool,
|
||||||
fa_version: int | None,
|
fa_version: int | None,
|
||||||
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
max_seqlen: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if is_rocm_aiter:
|
if is_rocm_aiter:
|
||||||
@ -38,6 +38,14 @@ def flash_attn_maxseqlen_wrapper(
|
|||||||
|
|
||||||
if not current_platform.is_rocm() and fa_version is not None:
|
if not current_platform.is_rocm() and fa_version is not None:
|
||||||
kwargs["fa_version"] = fa_version
|
kwargs["fa_version"] = fa_version
|
||||||
|
|
||||||
|
q_len = q.size(1)
|
||||||
|
if cu_seqlens is None:
|
||||||
|
cu_seqlens = torch.arange(
|
||||||
|
0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
|
||||||
|
)
|
||||||
|
max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
|
||||||
|
|
||||||
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
output = flash_attn_varlen_func(
|
output = flash_attn_varlen_func(
|
||||||
q,
|
q,
|
||||||
@ -45,8 +53,8 @@ def flash_attn_maxseqlen_wrapper(
|
|||||||
v,
|
v,
|
||||||
cu_seqlens_q=cu_seqlens,
|
cu_seqlens_q=cu_seqlens,
|
||||||
cu_seqlens_k=cu_seqlens,
|
cu_seqlens_k=cu_seqlens,
|
||||||
max_seqlen_q=max_seqlen.item(),
|
max_seqlen_q=max_seqlen,
|
||||||
max_seqlen_k=max_seqlen.item(),
|
max_seqlen_k=max_seqlen,
|
||||||
dropout_p=0.0,
|
dropout_p=0.0,
|
||||||
causal=False,
|
causal=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -79,24 +87,42 @@ def vit_flash_attn_wrapper(
|
|||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
|
||||||
max_seqlen: torch.Tensor,
|
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
is_rocm_aiter: bool,
|
is_rocm_aiter: bool,
|
||||||
fa_version: int | None,
|
fa_version: int | None,
|
||||||
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
max_seqlen: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
|
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
|
||||||
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
batch_size,
|
||||||
|
is_rocm_aiter,
|
||||||
|
fa_version,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Input shape:
|
||||||
|
(batch_size x seq_len x num_heads x head_size)
|
||||||
|
"""
|
||||||
|
q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
|
||||||
|
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
|
||||||
|
output = einops.rearrange(output, "b h s d -> b s h d ")
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
# TODO: Once we have a torch 2.10, we can use tensor slices
|
# TODO: Once we have a torch 2.10, we can use tensor slices
|
||||||
# so we won't need to wrap this in custom ops
|
# so we won't need to wrap this in custom ops
|
||||||
def torch_sdpa_wrapper(
|
def torch_sdpa_wrapper(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Never remove the contiguous logic for ROCm
|
# Never remove the contiguous logic for ROCm
|
||||||
# Without it, hallucinations occur with the backend
|
# Without it, hallucinations occur with the backend
|
||||||
@ -105,6 +131,9 @@ def torch_sdpa_wrapper(
|
|||||||
k = k.contiguous()
|
k = k.contiguous()
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
|
|
||||||
|
if cu_seqlens is None:
|
||||||
|
return apply_sdpa(q, k, v)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
@ -112,11 +141,7 @@ def torch_sdpa_wrapper(
|
|||||||
k_chunks = torch.split(k, lens, dim=1)
|
k_chunks = torch.split(k, lens, dim=1)
|
||||||
v_chunks = torch.split(v, lens, dim=1)
|
v_chunks = torch.split(v, lens, dim=1)
|
||||||
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||||
q_i, k_i, v_i = (
|
output_i = apply_sdpa(q_i, k_i, v_i)
|
||||||
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
|
||||||
)
|
|
||||||
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
|
|
||||||
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
|
|
||||||
outputs.append(output_i)
|
outputs.append(output_i)
|
||||||
context_layer = torch.cat(outputs, dim=1)
|
context_layer = torch.cat(outputs, dim=1)
|
||||||
return context_layer
|
return context_layer
|
||||||
@ -142,6 +167,6 @@ def vit_torch_sdpa_wrapper(
|
|||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
|
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from collections.abc import Iterable
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.distributed.utils import divide
|
from vllm.distributed.utils import divide
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||||
|
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = MMEncoderAttention(
|
||||||
self.num_heads_per_partition, self.head_dim, self.scale
|
self.num_heads_per_partition, self.head_dim, self.scale
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import Blip2VisionConfig, BlipVisionConfig
|
from transformers import Blip2VisionConfig, BlipVisionConfig
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||||
@ -122,7 +122,7 @@ class BlipAttention(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||||
|
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = MMEncoderAttention(
|
||||||
self.num_heads_per_partition, self.head_dim, self.scale
|
self.num_heads_per_partition, self.head_dim, self.scale
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,8 @@ from transformers import (
|
|||||||
CLIPVisionConfig,
|
CLIPVisionConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.attention.layer import Attention, MultiHeadAttention
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
@ -354,7 +355,7 @@ class CLIPAttention(nn.Module):
|
|||||||
quant_config: QuantizationConfig | None = None,
|
quant_config: QuantizationConfig | None = None,
|
||||||
*,
|
*,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
attn_cls: type[Attention] | type[MultiHeadAttention],
|
attn_cls: type[Attention] | type[MMEncoderAttention],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module):
|
|||||||
quant_config: QuantizationConfig | None = None,
|
quant_config: QuantizationConfig | None = None,
|
||||||
*,
|
*,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
attn_cls: type[Attention] | type[MultiHeadAttention],
|
attn_cls: type[Attention] | type[MMEncoderAttention],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = CLIPAttention(
|
self.self_attn = CLIPAttention(
|
||||||
@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module):
|
|||||||
num_hidden_layers_override: int | None = None,
|
num_hidden_layers_override: int | None = None,
|
||||||
*,
|
*,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
attn_cls: type[Attention] | type[MultiHeadAttention],
|
attn_cls: type[Attention] | type[MMEncoderAttention],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
num_hidden_layers_override=num_hidden_layers_override,
|
num_hidden_layers_override=num_hidden_layers_override,
|
||||||
prefix=f"{prefix}.encoder",
|
prefix=f"{prefix}.encoder",
|
||||||
attn_cls=MultiHeadAttention,
|
attn_cls=MMEncoderAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_hidden_layers = config.num_hidden_layers
|
num_hidden_layers = config.num_hidden_layers
|
||||||
|
|||||||
@ -18,7 +18,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import CLIPVisionConfig
|
from transformers import CLIPVisionConfig
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
num_hidden_layers_override=num_hidden_layers_override,
|
num_hidden_layers_override=num_hidden_layers_override,
|
||||||
prefix=f"{prefix}.encoder",
|
prefix=f"{prefix}.encoder",
|
||||||
attn_cls=MultiHeadAttention,
|
attn_cls=MMEncoderAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_hidden_layers = config.num_hidden_layers
|
num_hidden_layers = config.num_hidden_layers
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType
|
|||||||
from transformers.image_utils import ImageInput
|
from transformers.image_utils import ImageInput
|
||||||
from transformers.tokenization_utils_base import TextInput
|
from transformers.tokenization_utils_base import TextInput
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module):
|
|||||||
prefix=f"{prefix}.dense",
|
prefix=f"{prefix}.dense",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = MMEncoderAttention(
|
||||||
self.num_heads_per_rank, self.head_dim, self.scale
|
self.num_heads_per_rank, self.head_dim, self.scale
|
||||||
)
|
)
|
||||||
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
||||||
|
|||||||
@ -34,7 +34,7 @@ import torch.nn.functional as F
|
|||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.config import MultiModalConfig, VllmConfig
|
from vllm.config import MultiModalConfig, VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import parallel_state
|
from vllm.distributed import parallel_state
|
||||||
@ -232,7 +232,7 @@ class HunYuanVisionAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.scale = self.hidden_size_per_attention_head**-0.5
|
self.scale = self.hidden_size_per_attention_head**-0.5
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = MMEncoderAttention(
|
||||||
self.num_attention_heads_per_partition,
|
self.num_attention_heads_per_partition,
|
||||||
self.hidden_size_per_attention_head,
|
self.hidden_size_per_attention_head,
|
||||||
self.scale,
|
self.scale,
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from transformers.models.idefics2.configuration_idefics2 import (
|
|||||||
Idefics2VisionConfig,
|
Idefics2VisionConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||||
@ -161,8 +161,8 @@ class Idefics2VisionAttention(nn.Module):
|
|||||||
prefix=f"{prefix}.out_proj",
|
prefix=f"{prefix}.out_proj",
|
||||||
disable_tp=use_data_parallel,
|
disable_tp=use_data_parallel,
|
||||||
)
|
)
|
||||||
# Use unified MultiHeadAttention with Flash Attention support
|
# Use unified MMEncoderAttention with Flash Attention support
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = MMEncoderAttention(
|
||||||
self.num_heads_per_partition, self.head_dim, self.scale
|
self.num_heads_per_partition, self.head_dim, self.scale
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -175,7 +175,7 @@ class Idefics2VisionAttention(nn.Module):
|
|||||||
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
|
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
|
||||||
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
|
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
|
||||||
|
|
||||||
# Use unified MultiHeadAttention implementation
|
# Use unified MMEncoderAttention implementation
|
||||||
out = self.attn(query_states, key_states, value_states)
|
out = self.attn(query_states, key_states, value_states)
|
||||||
attn_output, _ = self.out_proj(out)
|
attn_output, _ = self.out_proj(out)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|||||||
@ -15,7 +15,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
divide,
|
divide,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
@ -207,7 +207,7 @@ class InternParallelAttention(nn.Module):
|
|||||||
disable_tp=use_data_parallel,
|
disable_tp=use_data_parallel,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = MMEncoderAttention(
|
||||||
self.num_heads_per_partition, self.head_dim, self.scale
|
self.num_heads_per_partition, self.head_dim, self.scale
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,7 @@ import torch.nn as nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from transformers.utils import torch_int
|
from transformers.utils import torch_int
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@ -214,8 +214,8 @@ class InternSdpaAttention(nn.Module):
|
|||||||
|
|
||||||
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
|
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
|
||||||
|
|
||||||
# Use unified MultiHeadAttention with automatic backend selection
|
# Use unified MMEncoderAttention with automatic backend selection
|
||||||
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
|
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""x shape: (B, N, C)"""
|
"""x shape: (B, N, C)"""
|
||||||
@ -228,7 +228,7 @@ class InternSdpaAttention(nn.Module):
|
|||||||
q = self.q_norm(q)
|
q = self.q_norm(q)
|
||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
|
|
||||||
# Use unified MultiHeadAttention with automatic backend selection
|
# Use unified MMEncoderAttention with automatic backend selection
|
||||||
x = self.attn(q, k, v)
|
x = self.attn(q, k, v)
|
||||||
|
|
||||||
x = self.projection_layer(x)
|
x = self.projection_layer(x)
|
||||||
|
|||||||
@ -31,7 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import (
|
|||||||
get_best_fit,
|
get_best_fit,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@ -255,7 +255,7 @@ class Llama4VisionAttention(nn.Module):
|
|||||||
self.attention_dropout = config.attention_dropout
|
self.attention_dropout = config.attention_dropout
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
|
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = MMEncoderAttention(
|
||||||
self.num_local_heads, self.head_dim, self.scaling
|
self.num_local_heads, self.head_dim, self.scaling
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -17,7 +17,8 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorT
|
|||||||
from transformers.image_utils import ImageInput
|
from transformers.image_utils import ImageInput
|
||||||
from transformers.tokenization_utils_base import TextInput
|
from transformers.tokenization_utils_base import TextInput
|
||||||
|
|
||||||
from vllm.attention.layer import Attention, MultiHeadAttention
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
@ -222,7 +223,7 @@ class MultiHeadDotProductAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.scale = self.head_dim**-0.5
|
self.scale = self.head_dim**-0.5
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = MMEncoderAttention(
|
||||||
self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
|
self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -16,8 +16,8 @@ from transformers import (
|
|||||||
SiglipVisionConfig,
|
SiglipVisionConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
|
||||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||||
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
@ -379,7 +379,7 @@ class SiglipAttention(nn.Module):
|
|||||||
quant_config: QuantizationConfig | None = None,
|
quant_config: QuantizationConfig | None = None,
|
||||||
*,
|
*,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -481,7 +481,7 @@ class SiglipEncoderLayer(nn.Module):
|
|||||||
quant_config: QuantizationConfig | None = None,
|
quant_config: QuantizationConfig | None = None,
|
||||||
*,
|
*,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -527,7 +527,7 @@ class SiglipEncoder(nn.Module):
|
|||||||
num_hidden_layers_override: int | None = None,
|
num_hidden_layers_override: int | None = None,
|
||||||
*,
|
*,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -700,7 +700,7 @@ class SiglipVisionTransformer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
num_hidden_layers_override=num_hidden_layers_override,
|
num_hidden_layers_override=num_hidden_layers_override,
|
||||||
prefix=f"{prefix}.encoder",
|
prefix=f"{prefix}.encoder",
|
||||||
attn_cls=MultiHeadAttention,
|
attn_cls=MMEncoderAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_hidden_layers = config.num_hidden_layers
|
num_hidden_layers = config.num_hidden_layers
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from torchvision import transforms
|
|||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@ -753,8 +753,8 @@ class Step3VisionAttention(nn.Module):
|
|||||||
disable_tp=use_data_parallel,
|
disable_tp=use_data_parallel,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use unified MultiHeadAttention with automatic backend selection
|
# Use unified MMEncoderAttention with automatic backend selection
|
||||||
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
|
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -767,7 +767,7 @@ class Step3VisionAttention(nn.Module):
|
|||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
|
|
||||||
# Use unified MultiHeadAttention with automatic backend selection
|
# Use unified MMEncoderAttention with automatic backend selection
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
|
|
||||||
attn_output, _ = self.out_proj(attn_output)
|
attn_output, _ = self.out_proj(attn_output)
|
||||||
|
|||||||
@ -16,9 +16,9 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.models.whisper.modeling_whisper import sinusoids
|
from transformers.models.whisper.modeling_whisper import sinusoids
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
from vllm.attention.layer import Attention, AttentionType
|
||||||
from vllm.attention.layer import Attention, MultiHeadAttention
|
|
||||||
from vllm.attention.layers.cross_attention import CrossAttention
|
from vllm.attention.layers.cross_attention import CrossAttention
|
||||||
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@ -141,7 +141,7 @@ class WhisperAudioInputs(TensorSchema):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class WhisperEncoderAttention(MultiHeadAttention):
|
class WhisperEncoderAttention(MMEncoderAttention):
|
||||||
"""Multi-headed attention for Whisper encoder with 2D tensor support."""
|
"""Multi-headed attention for Whisper encoder with 2D tensor support."""
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user