[Core] Deprecate xformers (#29262)

Signed-off-by: Roger Wang <hey@rogerw.io>
Roger Wang 2025-11-23 20:18:55 -08:00 committed by GitHub
parent 5253f4276f
commit 0ff70821c9
31 changed files with 77 additions and 963 deletions

View File

@ -76,34 +76,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/common.txt
# must be set before installing xformers, so it can install the correct version of xformers.
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# Build xformers with cuda and torch nightly
# following official xformers guidance: https://github.com/facebookresearch/xformers#build
# todo(elainewy): cache xformers build result for faster build
ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}
ARG XFORMERS_COMMIT=f2de641ef670510cadab099ce6954031f52f191c
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
echo 'git clone xformers...' \
&& git clone https://github.com/facebookresearch/xformers.git --recursive \
&& cd xformers \
&& git checkout ${XFORMERS_COMMIT} \
&& git submodule update --init --recursive \
&& echo 'finish git clone xformers...' \
&& rm -rf build \
&& python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
&& cd .. \
&& rm -rf xformers
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system xformers-dist/*.whl --verbose
# The build can take a long time, and the torch nightly version fetched from the URL can differ in the next docker stage.
# Track the nightly torch version used in the build so that, when we set up the runtime environment, we can make sure the version is the same.
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt
@ -233,11 +205,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system vllm-dist/*.whl --verbose
# install xformers again for the new environment
RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
# install package for build flashinfer
@ -307,7 +274,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/nightly_torch_test.txt
# Logging to confirm the torch versions
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
RUN pip freeze | grep -E 'torch|vllm|flashinfer'
# Logging to confirm all the packages are installed
RUN pip freeze

View File

@ -98,21 +98,6 @@ to warm it up so that future builds are faster.
<img width="60%" alt="Buildkite new build popup" src="https://github.com/user-attachments/assets/a8ff0fcd-76e0-4e91-b72f-014e3fdb6b94">
</p>
## Update dependencies
Several vLLM dependencies like xFormers depend on PyTorch and need
to be updated accordingly. Rather than waiting for all of them to publish new
releases (which would take too much time), they can be built from
source to unblock the update process.
### xFormers
```bash
export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a'
MAX_JOBS=16 uv pip install --system \
--no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2"
```
## Update all the different vLLM platforms
Rather than attempting to update all vLLM platforms in a single pull request, it's more manageable

View File

@ -283,7 +283,7 @@ Currently, vLLM supports multiple backends for efficient Attention computation a
If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options:
- On NVIDIA CUDA: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`.
- On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`.
- On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`.
For AMD ROCm, you can further control the specific Attention implementation using the following variables:

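For context, a minimal sketch of selecting one of the remaining backends via this variable; the model name is a placeholder and the variable must be set before the engine is constructed:

```python
# Illustrative only: pin the attention backend before building the engine.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"  # or "FLASHINFER"

import vllm

llm = vllm.LLM("facebook/opt-125m")  # placeholder model for this sketch
print(llm.generate("Hello, my name is"))
```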
View File

@ -22,7 +22,6 @@ API_KEY=${API_KEY:-"your-api-key"}
POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST
export VLLM_ENABLE_CHUNKED_PROCESSING=true
export CUDA_VISIBLE_DEVICES=2,3,4,5
# export VLLM_ATTENTION_BACKEND=XFORMERS
echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing"
echo "=================================================================="

View File

@ -9,6 +9,5 @@ torch==2.9.0
torchaudio==2.9.0
# These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers==0.0.33.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.5.2

View File

@ -74,9 +74,6 @@ def test_models(
model_executor: str,
enable_prompt_embeds: bool,
) -> None:
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
pytest.skip(f"{backend} does not support gemma2 with full context length.")
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", backend)

View File

@ -13,12 +13,6 @@ from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.platforms import current_platform
from vllm.utils.mem_utils import get_max_shared_memory_bytes
if not current_platform.is_rocm():
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from tests.kernels.utils import make_alibi_bias
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
@ -448,129 +442,6 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
seed: int,
device: str,
use_alibi: bool = False,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
max_len = min(MAX_SEQ_LEN, 4096)
seq_lens = random.sample(range(1, max_len), num_seqs)
num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
qkv = torch.empty(
num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype
)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1)
num_queries_per_kv = num_query_heads // num_kv_heads
if num_queries_per_kv > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
alibi_bias = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output = torch.empty_like(query)
start = 0
# Dynamic sequence length not supported with custom attn_bias.
for i, seq_len in enumerate(seq_lens):
end = start + seq_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale,
)
output[start:end].copy_(out.view_as(query[start:end]))
start += seq_len
# xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias = [
b.materialize((1, num_query_heads, i, i), device=device).squeeze()
for b, i in zip(attn_bias, seq_lens)
]
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
output = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=scale,
)
output = output.squeeze(0)
cu_seq_lens = [0]
for seq_len in seq_lens:
cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
ref_output = ref_multi_query_kv_attention(
cu_seq_lens,
query,
key,
value,
scale,
alibi_bias,
dtype,
)
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
num_seqs: int,
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
return test_multi_query_kv_attention(
num_seqs,
num_heads,
head_size,
dtype,
seed,
device,
use_alibi=True,
)
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
head_size = 64

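The deleted test packed variable-length sequences and relied on xformers' `BlockDiagonalCausalMask`. The same varlen causal computation can be expressed one sequence at a time with torch SDPA, mirroring the per-sequence loop this PR adds to `KeyeSiglipAttention`. A minimal sketch, not vLLM code, assuming equal query and KV head counts:

```python
import torch
import torch.nn.functional as F


def packed_causal_attention(
    query: torch.Tensor,  # [num_tokens, num_heads, head_size]
    key: torch.Tensor,    # [num_tokens, num_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_heads, head_size]
    seq_lens: list[int],
    scale: float,
) -> torch.Tensor:
    """Causal attention over packed variable-length sequences, one at a time."""
    outputs = []
    start = 0
    for seq_len in seq_lens:
        end = start + seq_len
        # [seq_len, H, D] -> [1, H, seq_len, D], the layout SDPA expects.
        q, k, v = (
            x[start:end].transpose(0, 1).unsqueeze(0)
            for x in (query, key, value)
        )
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
        outputs.append(out.squeeze(0).transpose(0, 1))  # back to [seq_len, H, D]
        start = end
    return torch.cat(outputs, dim=0)
```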
View File

@ -34,7 +34,7 @@ DEVICE_MLA_BACKENDS = {
}
DEVICE_REGULAR_ATTN_BACKENDS = {
"cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
"cuda": ["FLASHINFER", "FLASH_ATTN"],
"hip": ["ROCM_ATTN"],
"cpu": ["CPU_ATTN"],
}
@ -207,12 +207,6 @@ def test_env(
)
expected = "FLASHINFER"
assert backend.get_name() == expected
elif name == "XFORMERS":
backend = get_attn_backend(
32, torch.float16, None, block_size, use_mla=use_mla
)
expected = "XFORMERS"
assert backend.get_name() == expected
elif name == "FLASH_ATTN":
backend = get_attn_backend(
32, torch.float16, None, block_size, use_mla=use_mla

View File

@ -24,10 +24,6 @@ from vllm.platforms.rocm import RocmPlatform
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()
# Clear xformers availability cache
import vllm.attention.layer as layer_module
layer_module.USE_XFORMERS_OPS = None
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])

View File

@ -509,43 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
)
def make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_lens: list[int],
) -> list[Any]:
"""Create ALiBi biases compatible with xFormers attention tests."""
from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias
if alibi_slopes is None:
return [None for _ in seq_lens]
attn_biases: list[Any] = []
num_heads = alibi_slopes.shape[0]
assert num_heads >= num_kv_heads, (
"ALiBi slopes expect at least as many heads as KV heads"
)
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
bias = bias[None, :] - bias[:, None]
padded_len = (seq_len + 7) // 8 * 8
bias_tensor = torch.empty(
1,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias_tensor.mul_(alibi_slopes[:, None, None])
attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))
return attn_biases
def _make_metadata_tensors(
seq_lens: list[int] | None,
context_lens: list[int] | None,
@ -649,23 +612,12 @@ def make_kv_cache(
Returns:
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
* for backend 'XFORMERS'
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
* for backend 'FLASH_ATTN'
"""
if backend == "XFORMERS":
kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(
device
)
elif backend == "FLASH_ATTN":
kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(
device
)
else:
raise ValueError(
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
)
if backend != "FLASH_ATTN":
raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(device)
if default_val is not None:
kv_cache[:, :, :] = default_val
return kv_cache
@ -843,22 +795,14 @@ def assert_actual_matches_ideal(
* output_under_test: actually observed output value
"""
ideal_output = test_params.packed_qkvo.ideal_output
if backend == "XFORMERS":
torch.testing.assert_close(
ideal_output, output_under_test.view_as(ideal_output)
)
elif backend == "FLASH_ATTN":
# For FlashAttention override the accuracy thresholds to non default
# values since we notice a higher difference between the ideal and
# actual output.
torch.testing.assert_close(
ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
)
else:
raise ValueError(
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
)
if backend != "FLASH_ATTN":
raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
# For FlashAttention, override the accuracy thresholds to non-default
# values since we observe a larger difference between the ideal and
# actual output.
torch.testing.assert_close(
ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
)
# Copied/modified from torch._refs.__init__.py

View File

@ -57,10 +57,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
return generated_texts
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm",
)
def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
@ -84,10 +80,6 @@ def test_minicpmv_lora(minicpmv_lora_files):
@pytest.mark.skipif(
current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests"
)
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm",
)
@multi_gpu_test(num_gpus=4)
def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(
@ -108,10 +100,6 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
@pytest.mark.skipif(
current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests"
)
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm",
)
@multi_gpu_test(num_gpus=4)
def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(

View File

@ -2,12 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import pytest
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams
@ -142,10 +139,6 @@ QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="Qwen2-VL dependency xformers incompatible with ROCm",
)
def test_qwen2vl_lora(qwen2vl_lora_files):
"""Test Qwen 2.0 VL model with LoRA"""
config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)
@ -156,10 +149,6 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="Qwen2-VL dependency xformers incompatible with ROCm",
)
def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
"""Test Qwen 2.0 VL model with LoRA through beam search."""
config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)
@ -178,10 +167,6 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
)
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
)
def test_qwen25vl_lora(qwen25vl_lora_files):
"""Test Qwen 2.5 VL model with LoRA"""
config = TestConfig(model_path=QWEN25VL_MODEL_PATH, lora_path=qwen25vl_lora_files)

View File

@ -43,7 +43,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
ROCM_AITER_TRITON_MLA = (

View File

@ -51,31 +51,6 @@ else:
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
def check_xformers_availability():
global USE_XFORMERS_OPS
if USE_XFORMERS_OPS is not None:
return USE_XFORMERS_OPS
if current_platform.is_cuda() and current_platform.has_device_capability(100):
# Xformers FA is not compatible with B200
USE_XFORMERS_OPS = False
else:
try:
from importlib.util import find_spec
find_spec("xformers.ops")
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
# the warning only needs to be shown once
if not USE_XFORMERS_OPS:
logger.warning("Xformers is not available, falling back.")
return USE_XFORMERS_OPS
def check_upstream_fa_availability(dtype: torch.dtype):
@ -533,7 +508,6 @@ class MultiHeadAttention(nn.Module):
if backend
in {
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.PALLAS,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
@ -549,12 +523,6 @@ class MultiHeadAttention(nn.Module):
)
)
if (
self.attn_backend == AttentionBackendEnum.XFORMERS
and not check_xformers_availability()
):
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
@ -614,12 +582,6 @@ class MultiHeadAttention(nn.Module):
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(
query, key, value, scale=self.scale
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)

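The removed XFORMERS branch of `MultiHeadAttention` called `xops.memory_efficient_attention_forward` on `[B, S, H, D]` tensors, while the surviving TORCH_SDPA branch transposes to `[B, H, S, D]` first. A standalone sketch of the layout difference (shapes here are arbitrary):

```python
import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 16, 8, 64
scale = head_dim ** -0.5
q, k, v = (torch.randn(batch, seq, heads, head_dim) for _ in range(3))

# Removed path (required the optional xformers package), kept only as a comment:
#   from xformers import ops as xops
#   out = xops.memory_efficient_attention_forward(q, k, v, scale=scale)  # [B, S, H, D]

# Remaining TORCH_SDPA path: transpose to [B, H, S, D], attend, transpose back.
q_t, k_t, v_t = (x.transpose(1, 2) for x in (q, k, v))
out = F.scaled_dot_product_attention(q_t, k_t, v_t, scale=scale).transpose(1, 2)
print(out.shape)  # torch.Size([2, 16, 8, 64])
```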
View File

@ -3,7 +3,7 @@
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
`to_list` in xformers attn, or `.item()` in flash attention)
`.item()` in flash attention)
Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
@ -19,42 +19,6 @@ import torch.nn.functional as F
from vllm.utils.torch_utils import direct_register_custom_op
def xformers_attn_seqlens_wrapper(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer
def xformers_attn_seqlens_wrapper_fake(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op(
op_name="xformers_attn_seqlens_wrapper",
op_func=xformers_attn_seqlens_wrapper,
fake_impl=xformers_attn_seqlens_wrapper_fake,
)
def vit_xformers_attn_wrapper(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)
def flash_attn_maxseqlen_wrapper(
q: torch.Tensor,
k: torch.Tensor,

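The deleted wrapper shows the registration pattern the remaining ViT wrappers still follow: a real implementation plus a shape-only fake implementation registered through `direct_register_custom_op`, then invoked as `torch.ops.vllm.<name>` so `torch.compile` sees a single opaque node. A toy, hypothetical op sketched in that pattern (not part of this PR; the invocation assumes a CUDA device is available for the registered dispatch key):

```python
import torch

from vllm.utils.torch_utils import direct_register_custom_op


def toy_scale_wrapper(x: torch.Tensor, scale: float) -> torch.Tensor:
    return x * scale


def toy_scale_wrapper_fake(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Only shape/dtype matter for the fake (meta) implementation used in tracing.
    return torch.empty_like(x)


direct_register_custom_op(
    op_name="toy_scale_wrapper",
    op_func=toy_scale_wrapper,
    fake_impl=toy_scale_wrapper_fake,
)

if torch.cuda.is_available():  # assumption: the op is registered for the CUDA backend
    y = torch.ops.vllm.toy_scale_wrapper(torch.ones(4, device="cuda"), 2.0)
```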
View File

@ -36,7 +36,14 @@ def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
* None otherwise
"""
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return None if backend_name is None else AttentionBackendEnum[backend_name]
if backend_name is None:
return None
if backend_name == "XFORMERS":
raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
"details). Please select a supported attention backend."
)
return AttentionBackendEnum[backend_name]
# Global state allows a particular choice of backend

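A sketch of the new behavior, assuming the helper lives in `vllm.attention.selector` as in current vLLM: the removed name now fails fast with a clear error instead of resolving to a missing backend.

```python
import os

from vllm.attention.selector import get_env_variable_attn_backend

os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"
try:
    get_env_variable_attn_backend()
except ValueError as exc:
    print(exc)  # Attention backend 'XFORMERS' has been removed (See PR #29262 ...)

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
print(get_env_variable_attn_backend())  # AttentionBackendEnum.FLASH_ATTN
```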
View File

@ -173,6 +173,12 @@ class MultiModalConfig:
# We need to import the real type here (deferred to avoid circular import).
from vllm.attention.backends.registry import AttentionBackendEnum
if isinstance(value, str) and value.upper() == "XFORMERS":
raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
"details). Please select a supported attention backend."
)
if value is None or isinstance(value, AttentionBackendEnum):
return value

View File

@ -640,7 +640,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Example options:
# - "TORCH_SDPA": use torch.nn.MultiheadAttention
# - "FLASH_ATTN": use FlashAttention
# - "XFORMERS": use XFormers
# - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA

View File

@ -306,7 +306,6 @@ class DotsVisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@ -324,7 +323,6 @@ class DotsVisionAttention(nn.Module):
rotary_pos_emb: torch.Tensor | None = None,
*,
max_seqlen: int | None = None,
seqlens: list[int] | None = None,
) -> torch.Tensor:
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
@ -374,16 +372,6 @@ class DotsVisionAttention(nn.Module):
out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
else:
raise RuntimeError("Unsupported attention backend")
@ -545,14 +533,12 @@ class DotsVisionBlock(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None,
seqlens: list[int] | None = None,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@ -663,18 +649,14 @@ class DotsVisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens
return max_seqlen
def forward(
self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
@ -694,14 +676,13 @@ class DotsVisionTransformer(nn.Module):
)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
if self.post_trunk_norm is not None:

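The vision towers touched in this and the following files share the same `cu_seqlens` bookkeeping; with the xFormers branch gone they only need `max_seqlen` for the Flash Attention path. A standalone illustration with made-up lengths:

```python
import torch

seq_lens = [5, 3, 7]  # three packed sequences
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)  # cumsum with leading 0

lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([5, 3, 7], dtype=torch.int32)
max_seqlen = lengths.max().item()           # 7, all the Flash Attention path needs now
assert lengths.tolist() == seq_lens
```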
View File

@ -214,7 +214,6 @@ class Ernie4_5_VisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@ -259,7 +258,6 @@ class Ernie4_5_VisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@ -311,20 +309,6 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer)
return output
@ -404,14 +388,12 @@ class Ernie4_5_VisionBlock(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@ -562,18 +544,14 @@ class Ernie4_5_VisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens
return max_seqlen
def forward(
self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0
@ -598,8 +576,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
if hidden_states.ndim == 2:
hidden_states = hidden_states.unsqueeze(dim=1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
for i, blk in enumerate(self.blocks):
hidden_states = blk(
@ -607,7 +585,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
final_output = self.ln(hidden_states)

View File

@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer)
return output
@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
@ -803,15 +785,14 @@ class Glm4vVisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
) -> int | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens
return max_seqlen
def forward(
self,
@ -836,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module):
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
x = self.embeddings(
x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
)
@ -851,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
# adapter

View File

@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
@ -424,7 +425,7 @@ class KeyeSiglipAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@ -451,7 +452,6 @@ class KeyeSiglipAttention(nn.Module):
)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
batch_size = q.shape[0]
if rope_emb is None:
@ -498,17 +498,21 @@ class KeyeSiglipAttention(nn.Module):
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()

View File

@ -38,7 +38,6 @@ from vllm.attention.layer import (
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_xformers_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@ -657,7 +656,6 @@ class SiglipAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
seqlens: torch.Tensor | None,
) -> torch.Tensor:
batch_size, _, _ = hidden_states.shape
@ -703,10 +701,6 @@ class SiglipAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
if seqlens is None:
raise ValueError("xFormers attention backend requires seqlens tensor.")
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
else:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
@ -818,7 +812,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
seqlens: torch.Tensor | None,
) -> torch.Tensor:
residual = hidden_states
@ -828,7 +821,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
hidden_states = residual + hidden_states
@ -870,7 +862,6 @@ class SiglipEncoder(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@ -943,14 +934,11 @@ class SiglipEncoder(nn.Module):
cu_seqlens = cu_seqlens.to(device=device)
max_seqlen = None
seqlens = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
@ -959,7 +947,6 @@ class SiglipEncoder(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
return hidden_states

View File

@ -74,6 +74,7 @@ from .vision import (
)
try:
# Note: vLLM does not install xformers by default.
from xformers import ops as xops
if current_platform.is_cuda() and current_platform.has_device_capability(100):

View File

@ -46,7 +46,6 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
vit_xformers_attn_wrapper,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@ -375,7 +374,6 @@ class Qwen2_5_VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@ -435,8 +433,6 @@ class Qwen2_5_VisionAttention(nn.Module):
v,
cu_seqlens,
)
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
output, _ = self.proj(context_layer)
return output
@ -448,9 +444,7 @@ class Qwen2_5_VisionAttention(nn.Module):
"cu_seqlens": 0,
"rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0,
"seqlens": 0,
},
mark_unbacked_dims={"seqlens": 0},
enable_if=should_torch_compile_mm_vit,
)
class Qwen2_5_VisionBlock(nn.Module):
@ -501,7 +495,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@ -509,7 +502,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
@ -670,7 +662,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@ -822,17 +813,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
return max_seqlen
@staticmethod
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
@ -897,10 +885,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
# transformers
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
cu_window_seqlens
)
max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
@ -927,11 +913,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
max_seqlen_now = max_seqlen_full
seqlens_now = seqlens_full
else:
cu_seqlens_now = cu_window_seqlens
max_seqlen_now = max_seqlen_window
seqlens_now = seqlens_window
hidden_states = blk(
hidden_states,
@ -939,7 +923,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen_now,
seqlens=seqlens_now,
)
# For Qwen2.5-VL-3B, float16 will overflow at last block

View File

@ -348,7 +348,6 @@ class Qwen2VisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@ -384,7 +383,6 @@ class Qwen2VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, 3 * head * head_dim]
x, _ = self.qkv(x)
@ -445,20 +443,6 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer)
return output
@ -509,7 +493,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@ -517,7 +500,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@ -728,18 +710,14 @@ class Qwen2VisionTransformer(nn.Module):
sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined
def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
max_seqlen = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens
return max_seqlen
def forward(
self,
@ -771,7 +749,7 @@ class Qwen2VisionTransformer(nn.Module):
x = x.unsqueeze(1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
for blk in self.blocks:
x = blk(
@ -780,7 +758,6 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
# adapter

View File

@ -224,7 +224,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@ -232,7 +231,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@ -500,14 +498,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
return max_seqlen
def forward(
self,
@ -533,7 +528,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
hidden_states_list = []
deepstack_visual_indexes = self.deepstack_visual_indexes
@ -545,7 +540,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
if (
deepstack_visual_indexes is not None

View File

@ -235,7 +235,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@ -243,7 +242,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@ -391,7 +389,6 @@ class Qwen3_VisionTransformer(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@ -531,17 +528,14 @@ class Qwen3_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
return max_seqlen
def forward(
self,
@ -569,7 +563,7 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens = torch.from_numpy(cu_seqlens)
hidden_states = hidden_states.unsqueeze(1)
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
deepstack_feature_lists = []
@ -580,7 +574,6 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)

View File

@ -277,12 +277,7 @@ class CudaPlatformBase(Platform):
except ImportError:
pass
if cls.has_device_capability(100):
# xFormers doesn't support Blackwell, fall back to SDPA
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
return AttentionBackendEnum.TORCH_SDPA
else:
return AttentionBackendEnum.XFORMERS
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_valid_backends(

View File

@ -49,7 +49,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

View File

@ -1,420 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
try:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (
AttentionBias,
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
)
XFORMERS_AVAILABLE = True
except ImportError:
XFORMERS_AVAILABLE = False
from vllm import _custom_ops as ops
logger = init_logger(__name__)
class XFormersAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [
32,
40,
48,
56,
64,
72,
80,
88,
96,
104,
112,
120,
128,
136,
144,
152,
160,
168,
176,
184,
192,
200,
208,
216,
224,
232,
240,
248,
256,
]
@staticmethod
def get_name() -> str:
return "XFORMERS"
@staticmethod
def get_impl_cls() -> type["XFormersAttentionImpl"]:
return XFormersAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]:
return XFormersAttentionMetadataBuilder
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@dataclass
class XFormersAttentionMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
# Biases for different attention types.
attn_bias: Optional["AttentionBias"] = None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None
_cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata
q_start_loc = self.query_start_loc[self.num_decodes :]
q_seqlens = torch.diff(q_start_loc)
kv_seqlens = self.seq_lens[self.num_decodes :]
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = XFormersAttentionMetadata(
num_actual_tokens=self.num_prefill_tokens,
max_query_len=int(q_seqlens.max().item()),
query_start_loc=q_start_loc - q_start_loc[0],
max_seq_len=int(kv_seqlens.max().item()),
seq_lens=kv_seqlens,
block_table=self.block_table[self.num_decodes :],
slot_mapping=self.slot_mapping[self.num_decode_tokens :],
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
q_start_loc = self.query_start_loc
q_seqlens = torch.diff(q_start_loc)
decode_kv_seqlens = self.seq_lens[: self.num_decodes]
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = XFormersAttentionMetadata(
num_actual_tokens=self.num_decode_tokens,
max_query_len=int(q_seqlens[: self.num_decodes].max().item()),
query_start_loc=q_start_loc[: self.num_decodes + 1],
max_seq_len=int(decode_kv_seqlens.max().item()),
seq_lens=decode_kv_seqlens,
block_table=self.block_table[: self.num_decodes],
slot_mapping=self.slot_mapping[: self.num_decode_tokens],
attn_bias=self.attn_bias,
)
return self._cached_decode_metadata
class XFormersAttentionMetadataBuilder(
AttentionMetadataBuilder[XFormersAttentionMetadata]
):
reorder_batch_threshold: int = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert XFORMERS_AVAILABLE
self.block_size = kv_cache_spec.block_size
self._num_decodes = 0
self._num_decode_tokens = 0
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> XFormersAttentionMetadata:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
num_actual_tokens = common_attn_metadata.num_actual_tokens
q_start_loc = common_attn_metadata.query_start_loc
q_seqlens = torch.diff(q_start_loc)
max_query_len = common_attn_metadata.max_query_len
kv_seqlens = common_attn_metadata.seq_lens
max_seq_len = common_attn_metadata.max_seq_len
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
bias = None
if num_decodes > 0:
# Construct the decoder bias.
decode_q_seqlens = q_seqlens[:num_decodes]
decode_kv_seqlens = kv_seqlens[:num_decodes]
bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=decode_q_seqlens.tolist(),
kv_seqlen=decode_kv_seqlens.tolist(),
page_size=self.block_size,
block_tables=block_table[:num_decodes],
device=block_table.device,
)
return XFormersAttentionMetadata(
num_actual_tokens=num_actual_tokens,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_decodes=num_decodes,
max_query_len=max_query_len,
query_start_loc=q_start_loc,
max_seq_len=max_seq_len,
seq_lens=kv_seqlens,
block_table=block_table,
slot_mapping=slot_mapping,
attn_bias=bias,
)
class XFormersAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if alibi_slopes is not None:
raise NotImplementedError("XFormers does not support alibi slopes yet.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
if logits_soft_cap is None:
# Setting logits_soft_cap to 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"XFormersAttentionImpl."
)
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: XFormersAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with XFormers.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
# Cache the input KVs.
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
num_actual_tokens = attn_metadata.num_actual_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1])
unified_attention(
q=query[num_decode_tokens:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[num_decode_tokens:num_actual_tokens],
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=prefill_meta.block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
if decode_meta := attn_metadata.decode_metadata:
# Query for decode. KV is not needed because it is already cached.
decode_query = query[:num_decode_tokens]
# Reshape query to [1, B_T, G, H, D].
q = decode_query.view(
1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size
)
# Reshape the k and v caches to [1, Bkv_T, G, H, D]
cache_k = key_cache.view(
1, -1, self.num_kv_heads, 1, self.head_size
).expand(
1,
-1,
self.num_kv_heads,
self.num_queries_per_kv,
self.head_size,
)
cache_v = value_cache.view(
1, -1, self.num_kv_heads, 1, self.head_size
).expand(
1,
-1,
self.num_kv_heads,
self.num_queries_per_kv,
self.head_size,
)
attn_bias = decode_meta.attn_bias
output[:num_decode_tokens] = xops.memory_efficient_attention_forward(
q,
cache_k,
cache_v,
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
).view(decode_query.shape)
# Reshape the output tensor.
return output
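For reference, the decode path of the deleted backend broadcast each KV head across its query group by viewing the cache with a singleton group axis and expanding it. A standalone sketch of that trick with hypothetical shapes:

```python
import torch

num_tokens, num_kv_heads, num_queries_per_kv, head_size = 8, 4, 2, 64
key_cache = torch.randn(num_tokens, num_kv_heads, head_size)

# [1, Bkv_T, G, 1, D] expanded to [1, Bkv_T, G, H, D]; expand() performs no copy.
cache_k = key_cache.view(1, -1, num_kv_heads, 1, head_size).expand(
    1, -1, num_kv_heads, num_queries_per_kv, head_size
)
print(cache_k.shape)        # torch.Size([1, 8, 4, 2, 64])
print(cache_k.stride()[3])  # 0: the group axis is broadcast, not materialized
```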