[Core] Deprecate xformers (#29262)
Signed-off-by: Roger Wang <hey@rogerw.io>
parent 5253f4276f
commit 0ff70821c9
@@ -76,34 +76,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 RUN --mount=type=cache,target=/root/.cache/uv \
     uv pip install --system -r requirements/common.txt

-# must put before installing xformers, so it can install the correct version of xfomrers.
-ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
-ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
-
-# Build xformers with cuda and torch nightly
-# following official xformers guidance: https://github.com/facebookresearch/xformers#build
-# todo(elainewy): cache xformers build result for faster build
-ARG max_jobs=16
-ENV MAX_JOBS=${max_jobs}
-ARG XFORMERS_COMMIT=f2de641ef670510cadab099ce6954031f52f191c
-
-ENV CCACHE_DIR=/root/.cache/ccache
-RUN --mount=type=cache,target=/root/.cache/ccache \
-    --mount=type=cache,target=/root/.cache/uv \
-    echo 'git clone xformers...' \
-    && git clone https://github.com/facebookresearch/xformers.git --recursive \
-    && cd xformers \
-    && git checkout ${XFORMERS_COMMIT} \
-    && git submodule update --init --recursive \
-    && echo 'finish git clone xformers...' \
-    && rm -rf build \
-    && python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
-    && cd .. \
-    && rm -rf xformers
-
-RUN --mount=type=cache,target=/root/.cache/uv \
-    uv pip install --system xformers-dist/*.whl --verbose
-
 # build can take a long time, and the torch nightly version fetched from url can be different in next docker stage.
 # track the nightly torch version used in the build, when we set up runtime environment we can make sure the version is the same
 RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt

@@ -233,11 +205,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm
     --mount=type=cache,target=/root/.cache/uv \
     uv pip install --system vllm-dist/*.whl --verbose

-# install xformers again for the new environment
-RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \
-    --mount=type=cache,target=/root/.cache/uv \
-    uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose
-
 ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'

 # install package for build flashinfer

@@ -307,7 +274,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
     uv pip install --system -r requirements/nightly_torch_test.txt

 # Logging to confirm the torch versions
-RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
+RUN pip freeze | grep -E 'torch|vllm|flashinfer'

 # Logging to confirm all the packages are installed
 RUN pip freeze

@@ -98,21 +98,6 @@ to warm it up so that future builds are faster.
 <img width="60%" alt="Buildkite new build popup" src="https://github.com/user-attachments/assets/a8ff0fcd-76e0-4e91-b72f-014e3fdb6b94">
 </p>

-## Update dependencies
-
-Several vLLM dependencies like xFormers depend on PyTorch and need
-to be updated accordingly. Rather than waiting for all of them to publish new
-releases (which would take too much time), they can be built from
-source to unblock the update process.
-
-### xFormers
-
-```bash
-export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a'
-MAX_JOBS=16 uv pip install --system \
-  --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2"
-```
-
 ## Update all the different vLLM platforms

 Rather than attempting to update all vLLM platforms in a single pull request, it's more manageable

@@ -283,7 +283,7 @@ Currently, vLLM supports multiple backends for efficient Attention computation a

 If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options:

-- On NVIDIA CUDA: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`.
+- On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`.
 - On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`.

 For AMD ROCm, you can further control the specific Attention implementation using the following variables:
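For reference, explicitly selecting one of the remaining CUDA backends after this change looks like the minimal sketch below (the model name is only a placeholder, and setting the variable from Python rather than the shell is just one option):

```python
import os

# Must be set before vLLM chooses an attention backend; "XFORMERS" is no longer accepted.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

from vllm import LLM

# Placeholder model, used only for illustration.
llm = LLM(model="facebook/opt-125m")
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```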
@@ -22,7 +22,6 @@ API_KEY=${API_KEY:-"your-api-key"}
 POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST
 export VLLM_ENABLE_CHUNKED_PROCESSING=true
 export CUDA_VISIBLE_DEVICES=2,3,4,5
-# export VLLM_ATTENTION_BACKEND=XFORMERS

 echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing"
 echo "=================================================================="

@@ -9,6 +9,5 @@ torch==2.9.0
 torchaudio==2.9.0
 # These must be updated alongside torch
 torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
-xformers==0.0.33.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
 # FlashInfer should be updated together with the Dockerfile
 flashinfer-python==0.5.2

@@ -74,9 +74,6 @@ def test_models(
     model_executor: str,
     enable_prompt_embeds: bool,
 ) -> None:
-    if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
-        pytest.skip(f"{backend} does not support gemma2 with full context length.")
-
     with monkeypatch.context() as m:
         m.setenv("VLLM_ATTENTION_BACKEND", backend)

@@ -13,12 +13,6 @@ from vllm.attention.layer import Attention, MultiHeadAttention
 from vllm.platforms import current_platform
 from vllm.utils.mem_utils import get_max_shared_memory_bytes

-if not current_platform.is_rocm():
-    from xformers import ops as xops
-    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
-
-    from tests.kernels.utils import make_alibi_bias
-
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer

@@ -448,129 +442,6 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)


-@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(
-    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
-)
-@torch.inference_mode()
-def test_multi_query_kv_attention(
-    num_seqs: int,
-    num_heads: tuple[int, int],
-    head_size: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-    use_alibi: bool = False,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
-    # As the xformers library is already tested with its own tests, we can use
-    # a smaller MAX_SEQ_LEN here.
-    max_len = min(MAX_SEQ_LEN, 4096)
-    seq_lens = random.sample(range(1, max_len), num_seqs)
-    num_tokens = sum(seq_lens)
-
-    scale = float(1.0 / (head_size**0.5))
-    num_query_heads, num_kv_heads = num_heads
-    qkv = torch.empty(
-        num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype
-    )
-    qkv.uniform_(-scale, scale)
-    query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1)
-
-    num_queries_per_kv = num_query_heads // num_kv_heads
-    if num_queries_per_kv > 1:
-        # Handle MQA and GQA
-        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
-        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
-    alibi_bias = None
-    if use_alibi:
-        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
-        attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
-        output = torch.empty_like(query)
-        start = 0
-        # Dynamic sequence length not supported with custom attn_bias.
-        for i, seq_len in enumerate(seq_lens):
-            end = start + seq_len
-            out = xops.memory_efficient_attention_forward(
-                query[None, start:end],
-                key[None, start:end],
-                value[None, start:end],
-                attn_bias=attn_bias[i],
-                p=0.0,
-                scale=scale,
-            )
-            output[start:end].copy_(out.view_as(query[start:end]))
-            start += seq_len
-        # xformers.AttentionBias to Tensor for use in reference impl.
-        alibi_bias = [
-            b.materialize((1, num_query_heads, i, i), device=device).squeeze()
-            for b, i in zip(attn_bias, seq_lens)
-        ]
-    else:
-        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
-        output = xops.memory_efficient_attention_forward(
-            query.unsqueeze(0),
-            key.unsqueeze(0),
-            value.unsqueeze(0),
-            attn_bias=attn_bias,
-            p=0.0,
-            scale=scale,
-        )
-        output = output.squeeze(0)
-
-    cu_seq_lens = [0]
-    for seq_len in seq_lens:
-        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
-    ref_output = ref_multi_query_kv_attention(
-        cu_seq_lens,
-        query,
-        key,
-        value,
-        scale,
-        alibi_bias,
-        dtype,
-    )
-    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
-    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
-    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
-
-
-@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", [64])
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(
-    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
-)
-@torch.inference_mode()
-def test_multi_query_kv_attention_with_alibi(
-    num_seqs: int,
-    num_heads: tuple[int, int],
-    head_size: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> None:
-    return test_multi_query_kv_attention(
-        num_seqs,
-        num_heads,
-        head_size,
-        dtype,
-        seed,
-        device,
-        use_alibi=True,
-    )
-
-
 @pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
 def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
     head_size = 64
@@ -34,7 +34,7 @@ DEVICE_MLA_BACKENDS = {
 }

 DEVICE_REGULAR_ATTN_BACKENDS = {
-    "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
+    "cuda": ["FLASHINFER", "FLASH_ATTN"],
     "hip": ["ROCM_ATTN"],
     "cpu": ["CPU_ATTN"],
 }

@@ -207,12 +207,6 @@ def test_env(
         )
         expected = "FLASHINFER"
         assert backend.get_name() == expected
-    elif name == "XFORMERS":
-        backend = get_attn_backend(
-            32, torch.float16, None, block_size, use_mla=use_mla
-        )
-        expected = "XFORMERS"
-        assert backend.get_name() == expected
     elif name == "FLASH_ATTN":
         backend = get_attn_backend(
             32, torch.float16, None, block_size, use_mla=use_mla

@@ -24,10 +24,6 @@ from vllm.platforms.rocm import RocmPlatform
 def clear_cache():
     """Clear lru cache to ensure each test case runs without caching."""
     _cached_get_attn_backend.cache_clear()
-    # Clear xformers availability cache
-    import vllm.attention.layer as layer_module
-
-    layer_module.USE_XFORMERS_OPS = None


 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])

@@ -509,43 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
     )


-def make_alibi_bias(
-    alibi_slopes: torch.Tensor,
-    num_kv_heads: int,
-    dtype: torch.dtype,
-    seq_lens: list[int],
-) -> list[Any]:
-    """Create ALiBi biases compatible with xFormers attention tests."""
-    from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias
-
-    if alibi_slopes is None:
-        return [None for _ in seq_lens]
-
-    attn_biases: list[Any] = []
-    num_heads = alibi_slopes.shape[0]
-    assert num_heads >= num_kv_heads, (
-        "ALiBi slopes expect at least as many heads as KV heads"
-    )
-
-    for seq_len in seq_lens:
-        bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
-        bias = bias[None, :] - bias[:, None]
-
-        padded_len = (seq_len + 7) // 8 * 8
-        bias_tensor = torch.empty(
-            1,
-            num_heads,
-            seq_len,
-            padded_len,
-            device=alibi_slopes.device,
-            dtype=dtype,
-        )[:, :, :, :seq_len].copy_(bias)
-        bias_tensor.mul_(alibi_slopes[:, None, None])
-        attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))
-
-    return attn_biases
-
-
 def _make_metadata_tensors(
     seq_lens: list[int] | None,
     context_lens: list[int] | None,
@@ -649,23 +612,12 @@ def make_kv_cache(

     Returns:

-    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
-        * for backend 'XFORMERS'
     * kv_cache: 2 x num_blocks x block_size x num_heads x head_size
         * for backend 'FLASH_ATTN'
     """
-    if backend == "XFORMERS":
-        kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(
-            device
-        )
-    elif backend == "FLASH_ATTN":
-        kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(
-            device
-        )
-    else:
-        raise ValueError(
-            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
-        )
+    if backend != "FLASH_ATTN":
+        raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
+    kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(device)
     if default_val is not None:
         kv_cache[:, :, :] = default_val
     return kv_cache
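The docstring above pins down the single remaining cache layout; a tiny sketch of that shape in practice (the dimension values here are arbitrary illustrations, not values used by the tests):

```python
import torch

# FLASH_ATTN-style KV cache layout produced by the test helper above:
# [2, num_blocks, block_size, num_heads, head_size]
num_blocks, block_size, num_heads, head_size = 32, 16, 8, 64
kv_cache = torch.rand(2, num_blocks, block_size, num_heads, head_size)
key_cache, value_cache = kv_cache[0], kv_cache[1]
assert key_cache.shape == (num_blocks, block_size, num_heads, head_size)
```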
@@ -843,22 +795,14 @@ def assert_actual_matches_ideal(
     * output_under_test: actually observed output value
     """
     ideal_output = test_params.packed_qkvo.ideal_output
-    if backend == "XFORMERS":
-        torch.testing.assert_close(
-            ideal_output, output_under_test.view_as(ideal_output)
-        )
-
-    elif backend == "FLASH_ATTN":
-        # For FlashAttention override the accuracy thresholds to non default
-        # values since we notice a higher difference between the ideal and
-        # actual output.
-        torch.testing.assert_close(
-            ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
-        )
-    else:
-        raise ValueError(
-            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
-        )
+    if backend != "FLASH_ATTN":
+        raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
+    # For FlashAttention override the accuracy thresholds to non default
+    # values since we notice a higher difference between the ideal and
+    # actual output.
+    torch.testing.assert_close(
+        ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
+    )


 # Copied/modified from torch._refs.__init__.py
@@ -57,10 +57,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
     return generated_texts


-@pytest.mark.xfail(
-    current_platform.is_rocm(),
-    reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
 def test_minicpmv_lora(minicpmv_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,

@@ -84,10 +80,6 @@ def test_minicpmv_lora(minicpmv_lora_files):
 @pytest.mark.skipif(
     current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests"
 )
-@pytest.mark.xfail(
-    current_platform.is_rocm(),
-    reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
 @multi_gpu_test(num_gpus=4)
 def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
     llm = vllm.LLM(

@@ -108,10 +100,6 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
 @pytest.mark.skipif(
     current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests"
 )
-@pytest.mark.xfail(
-    current_platform.is_rocm(),
-    reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
 @multi_gpu_test(num_gpus=4)
 def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
     llm = vllm.LLM(

@@ -2,12 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass

-import pytest
-
 import vllm
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
-from vllm.platforms import current_platform
 from vllm.sampling_params import BeamSearchParams


@@ -142,10 +139,6 @@ QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
 QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"


-@pytest.mark.xfail(
-    current_platform.is_rocm(),
-    reason="Qwen2-VL dependency xformers incompatible with ROCm",
-)
 def test_qwen2vl_lora(qwen2vl_lora_files):
     """Test Qwen 2.0 VL model with LoRA"""
     config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)

@@ -156,10 +149,6 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
     tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)


-@pytest.mark.xfail(
-    current_platform.is_rocm(),
-    reason="Qwen2-VL dependency xformers incompatible with ROCm",
-)
 def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
     """Test Qwen 2.0 VL model with LoRA through beam search."""
     config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)

@@ -178,10 +167,6 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
     )


-@pytest.mark.xfail(
-    current_platform.is_rocm(),
-    reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
-)
 def test_qwen25vl_lora(qwen25vl_lora_files):
     """Test Qwen 2.5 VL model with LoRA"""
     config = TestConfig(model_path=QWEN25VL_MODEL_PATH, lora_path=qwen25vl_lora_files)

@@ -43,7 +43,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):

     FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
     TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
-    XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
     ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
     ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
     ROCM_AITER_TRITON_MLA = (
@@ -51,31 +51,6 @@ else:

 FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
-USE_XFORMERS_OPS = None
-
-
-def check_xformers_availability():
-    global USE_XFORMERS_OPS
-    if USE_XFORMERS_OPS is not None:
-        return USE_XFORMERS_OPS
-
-    if current_platform.is_cuda() and current_platform.has_device_capability(100):
-        # Xformers FA is not compatible with B200
-        USE_XFORMERS_OPS = False
-    else:
-        try:
-            from importlib.util import find_spec
-
-            find_spec("xformers.ops")
-            USE_XFORMERS_OPS = True
-        except ImportError:
-            USE_XFORMERS_OPS = False
-
-    # the warning only needs to be shown once
-    if not USE_XFORMERS_OPS:
-        logger.warning("Xformers is not available, falling back.")
-
-    return USE_XFORMERS_OPS


 def check_upstream_fa_availability(dtype: torch.dtype):
@@ -533,7 +508,6 @@ class MultiHeadAttention(nn.Module):
             if backend
             in {
                 AttentionBackendEnum.TORCH_SDPA,
-                AttentionBackendEnum.XFORMERS,
                 AttentionBackendEnum.PALLAS,
                 AttentionBackendEnum.ROCM_AITER_FA,
                 AttentionBackendEnum.FLASH_ATTN,

@@ -549,12 +523,6 @@ class MultiHeadAttention(nn.Module):
             )
         )

-        if (
-            self.attn_backend == AttentionBackendEnum.XFORMERS
-            and not check_xformers_availability()
-        ):
-            self.attn_backend = AttentionBackendEnum.TORCH_SDPA
-
         self.is_flash_attn_backend = self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,

@@ -614,12 +582,6 @@ class MultiHeadAttention(nn.Module):
                 max_seqlen_k=kv_len,
                 softmax_scale=self.scale,
             )
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            from xformers import ops as xops
-
-            out = xops.memory_efficient_attention_forward(
-                query, key, value, scale=self.scale
-            )
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             query, key, value = (x.transpose(1, 2) for x in (query, key, value))
             out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
@@ -3,7 +3,7 @@
 """
 This file contains ops for ViT attention to be compatible with torch.compile
 as there are operations here not supported by torch.compile (for instance,
-`to_list` in xformers attn, or `.item()` in flash attention)
+`.item()` in flash attention)

 Using these ops and wrapping vision blocks with `torch.compile` can speed up
 throughput in vision models by ~5% relative on H100, and improve token
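The docstring above describes the pattern this file relies on: each attention path is exposed as a registered custom op with a shape-only fake implementation so `torch.compile` can trace vision blocks through it. A minimal sketch of that registration pattern, with hypothetical names (`my_sdpa_attn_wrapper` and friends are illustrations, not ops defined in vLLM):

```python
import torch
import torch.nn.functional as F

from vllm.utils.torch_utils import direct_register_custom_op


def my_sdpa_attn_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    # Real implementation: plain SDPA over [batch, heads, seq, head_dim] tensors.
    return F.scaled_dot_product_attention(q, k, v)


def my_sdpa_attn_wrapper_fake(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    # Fake (meta) implementation: only shapes and dtypes matter during tracing.
    return torch.empty_like(q)


direct_register_custom_op(
    op_name="my_sdpa_attn_wrapper",
    op_func=my_sdpa_attn_wrapper,
    fake_impl=my_sdpa_attn_wrapper_fake,
)

# Vision code would then call torch.ops.vllm.my_sdpa_attn_wrapper(q, k, v), which
# torch.compile treats as an opaque op instead of tracing into its body.
```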
@@ -19,42 +19,6 @@ import torch.nn.functional as F
 from vllm.utils.torch_utils import direct_register_custom_op


-def xformers_attn_seqlens_wrapper(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
-    from xformers import ops as xops
-    from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
-    attn_bias = BlockDiagonalMask.from_seqlens(
-        q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
-    )
-    context_layer = xops.memory_efficient_attention_forward(
-        q, k, v, attn_bias=attn_bias, p=0, scale=None
-    )
-    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
-    return context_layer
-
-
-def xformers_attn_seqlens_wrapper_fake(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
-    b, s, h, d = q.shape
-    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
-
-
-direct_register_custom_op(
-    op_name="xformers_attn_seqlens_wrapper",
-    op_func=xformers_attn_seqlens_wrapper,
-    fake_impl=xformers_attn_seqlens_wrapper_fake,
-)
-
-
-def vit_xformers_attn_wrapper(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
-    return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)
-
-
 def flash_attn_maxseqlen_wrapper(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -36,7 +36,14 @@ def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
     * None otherwise
     """
     backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
-    return None if backend_name is None else AttentionBackendEnum[backend_name]
+    if backend_name is None:
+        return None
+    if backend_name == "XFORMERS":
+        raise ValueError(
+            "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
+            "details). Please select a supported attention backend."
+        )
+    return AttentionBackendEnum[backend_name]


 # Global state allows a particular choice of backend
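A short sketch of how the new guard behaves from a caller's point of view (pytest-style; the import path is an assumption, while the function name and error text come from the hunk above):

```python
import pytest

# Assumed import path for the selector helper shown above.
from vllm.attention.selector import get_env_variable_attn_backend


def test_removed_xformers_value_is_rejected(monkeypatch):
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
    with pytest.raises(ValueError, match="XFORMERS"):
        get_env_variable_attn_backend()


def test_supported_value_still_resolves(monkeypatch):
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
    assert get_env_variable_attn_backend() is not None
```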
@@ -173,6 +173,12 @@ class MultiModalConfig:
         # We need to import the real type here (deferred to avoid circular import).
         from vllm.attention.backends.registry import AttentionBackendEnum

+        if isinstance(value, str) and value.upper() == "XFORMERS":
+            raise ValueError(
+                "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
+                "details). Please select a supported attention backend."
+            )
+
         if value is None or isinstance(value, AttentionBackendEnum):
             return value

@@ -640,7 +640,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Example options:
     # - "TORCH_SDPA": use torch.nn.MultiheadAttention
     # - "FLASH_ATTN": use FlashAttention
-    # - "XFORMERS": use XFormers
     # - "FLASHINFER": use flashinfer
     # - "FLASHMLA": use FlashMLA
     # - "FLASH_ATTN_MLA": use FlashAttention for MLA
@@ -306,7 +306,6 @@ class DotsVisionAttention(nn.Module):
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(

@@ -324,7 +323,6 @@ class DotsVisionAttention(nn.Module):
         rotary_pos_emb: torch.Tensor | None = None,
         *,
         max_seqlen: int | None = None,
-        seqlens: list[int] | None = None,
     ) -> torch.Tensor:
         # [S, C] -> [S, B=1, C]
         x = hidden_states.unsqueeze(1)

@@ -374,16 +372,6 @@ class DotsVisionAttention(nn.Module):
                 out_i = out_i.permute(0, 2, 1, 3)
                 outputs.append(out_i)
             context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            from xformers import ops as xops
-            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
-            attn_bias = BlockDiagonalMask.from_seqlens(
-                q_seqlen=seqlens, kv_seqlen=None, device=q.device
-            )
-            context_layer = xops.memory_efficient_attention_forward(
-                q, k, v, attn_bias=attn_bias, p=0, scale=None
-            )
         else:
             raise RuntimeError("Unsupported attention backend")

@@ -545,14 +533,12 @@ class DotsVisionBlock(nn.Module):
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
         max_seqlen: int | None = None,
-        seqlens: list[int] | None = None,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states

@@ -663,18 +649,14 @@ class DotsVisionTransformer(nn.Module):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb

-    def compute_attn_mask_seqlen(
-        self, cu_seqlens: torch.Tensor
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+        max_seqlen = None
         if (
             self.attn_backend == AttentionBackendEnum.FLASH_ATTN
             or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        return max_seqlen, seqlens
+        return max_seqlen

     def forward(
         self, hidden_states: torch.Tensor, grid_thw: list[list[int]]

@@ -694,14 +676,13 @@ class DotsVisionTransformer(nn.Module):
         )
         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
         for blk in self.blocks:
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens,
                 rotary_pos_emb=rotary_pos_emb,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )

         if self.post_trunk_norm is not None:
@@ -214,7 +214,6 @@ class Ernie4_5_VisionAttention(nn.Module):
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(

@@ -259,7 +258,6 @@ class Ernie4_5_VisionAttention(nn.Module):
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)

@@ -311,20 +309,6 @@ class Ernie4_5_VisionAttention(nn.Module):
             context_layer = rearrange(
                 context_layer, "b s h d -> s b (h d)"
             ).contiguous()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            from xformers import ops as xops
-            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
-            attn_bias = BlockDiagonalMask.from_seqlens(
-                q_seqlen=seqlens, kv_seqlen=None, device=q.device
-            )
-
-            context_layer = xops.memory_efficient_attention_forward(
-                q, k, v, attn_bias=attn_bias, p=0, scale=None
-            )
-            context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()

         output, _ = self.proj(context_layer)
         return output

@@ -404,14 +388,12 @@ class Ernie4_5_VisionBlock(nn.Module):
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states

@@ -562,18 +544,14 @@ class Ernie4_5_VisionTransformer(nn.Module):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb

-    def compute_attn_mask_seqlen(
-        self, cu_seqlens: torch.Tensor
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+        max_seqlen = None
         if (
             self.attn_backend == AttentionBackendEnum.FLASH_ATTN
             or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        return max_seqlen, seqlens
+        return max_seqlen

     def forward(
         self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0

@@ -598,8 +576,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
         if hidden_states.ndim == 2:
             hidden_states = hidden_states.unsqueeze(dim=1)

-        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)

         for i, blk in enumerate(self.blocks):
             hidden_states = blk(

@@ -607,7 +585,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
                 cu_seqlens=cu_seqlens,
                 rotary_pos_emb=rotary_pos_emb,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )

         final_output = self.ln(hidden_states)
@@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module):
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(

@@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)

@@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module):
             context_layer = rearrange(
                 context_layer, "b s h d -> s b (h d)"
             ).contiguous()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            from xformers import ops as xops
-            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
-            attn_bias = BlockDiagonalMask.from_seqlens(
-                q_seqlen=seqlens, kv_seqlen=None, device=q.device
-            )
-
-            context_layer = xops.memory_efficient_attention_forward(
-                q, k, v, attn_bias=attn_bias, p=0, scale=None
-            )
-            context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()

         output, _ = self.proj(context_layer)
         return output

@@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),

@@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module):
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         x_fused_norm, residual = self.norm2(x, residual=x_attn)
         x = residual + self.mlp(x_fused_norm)

@@ -803,15 +785,14 @@ class Glm4vVisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
-        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+    ) -> int | None:
+        max_seqlen = None
         if (
             self.attn_backend == AttentionBackendEnum.FLASH_ATTN
             or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        return max_seqlen, seqlens
+        return max_seqlen

     def forward(
         self,

@@ -836,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module):
         ).cumsum(dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

-        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
+        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         x = self.embeddings(
             x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
         )

@@ -851,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module):
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )

         # adapter
@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange
 from transformers import PretrainedConfig
 from transformers.activations import GELUActivation

@@ -424,7 +425,7 @@ class KeyeSiglipAttention(nn.Module):

         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.XFORMERS,
+            AttentionBackendEnum.TORCH_SDPA,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(

@@ -451,7 +452,6 @@ class KeyeSiglipAttention(nn.Module):
         )

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         batch_size = q.shape[0]

         if rope_emb is None:

@@ -498,17 +498,21 @@ class KeyeSiglipAttention(nn.Module):
                 softmax_scale=self.scale,
             )
             context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            from xformers import ops as xops
-            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
-            attn_bias = BlockDiagonalMask.from_seqlens(
-                q_seqlen=seqlens, kv_seqlen=None, device=q.device
-            )
-
-            context_layer = xops.memory_efficient_attention_forward(
-                q, k, v, attn_bias=attn_bias, p=0, scale=None
-            )
+        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
+            outputs = []
+            for i in range(1, len(cu_seqlens)):
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (
+                    rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
+                )
+                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]

         context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()

@ -38,7 +38,6 @@ from vllm.attention.layer import (
|
|||||||
)
|
)
|
||||||
from vllm.attention.ops.vit_attn_wrappers import (
|
from vllm.attention.ops.vit_attn_wrappers import (
|
||||||
vit_flash_attn_wrapper,
|
vit_flash_attn_wrapper,
|
||||||
vit_xformers_attn_wrapper,
|
|
||||||
)
|
)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
@ -657,7 +656,6 @@ class SiglipAttention(nn.Module):
|
|||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
rotary_pos_emb: torch.Tensor | None,
|
rotary_pos_emb: torch.Tensor | None,
|
||||||
max_seqlen: torch.Tensor | None,
|
max_seqlen: torch.Tensor | None,
|
||||||
seqlens: torch.Tensor | None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
batch_size, _, _ = hidden_states.shape
|
batch_size, _, _ = hidden_states.shape
|
||||||
|
|
||||||
@ -703,10 +701,6 @@ class SiglipAttention(nn.Module):
|
|||||||
context_layer = rearrange(
|
context_layer = rearrange(
|
||||||
context_layer, "b s h d -> s b (h d)"
|
context_layer, "b s h d -> s b (h d)"
|
||||||
).contiguous()
|
).contiguous()
|
||||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
|
||||||
if seqlens is None:
|
|
||||||
raise ValueError("xFormers attention backend requires seqlens tensor.")
|
|
||||||
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
|
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
|
||||||
@ -818,7 +812,6 @@ class SiglipEncoderLayer(nn.Module):
|
|||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
rotary_pos_emb: torch.Tensor | None,
|
rotary_pos_emb: torch.Tensor | None,
|
||||||
max_seqlen: torch.Tensor | None,
|
max_seqlen: torch.Tensor | None,
|
||||||
seqlens: torch.Tensor | None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
@ -828,7 +821,6 @@ class SiglipEncoderLayer(nn.Module):
|
|||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
rotary_pos_emb=rotary_pos_emb,
|
rotary_pos_emb=rotary_pos_emb,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
seqlens=seqlens,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -870,7 +862,6 @@ class SiglipEncoder(nn.Module):
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
@@ -943,14 +934,11 @@ class SiglipEncoder(nn.Module):
         cu_seqlens = cu_seqlens.to(device=device)

         max_seqlen = None
-        seqlens = None
         if self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
@@ -959,7 +947,6 @@ class SiglipEncoder(nn.Module):
                 cu_seqlens=cu_seqlens,
                 rotary_pos_emb=rotary_pos_emb,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )
         return hidden_states

@@ -74,6 +74,7 @@ from .vision import (
 )

 try:
+    # Note: vLLM does not install xformers by default.
     from xformers import ops as xops

     if current_platform.is_cuda() and current_platform.has_device_capability(100):
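The hunk above shows only the top of the optional-import guard around xformers. A minimal sketch of the pattern it relies on, with the flag name and helper as illustrative assumptions rather than the file's actual symbols:

try:
    from xformers import ops as xops
    XFORMERS_AVAILABLE = True
except ImportError:
    xops = None
    XFORMERS_AVAILABLE = False

def memory_efficient_attention_enabled() -> bool:
    # Callers check the flag instead of importing xformers unconditionally.
    return XFORMERS_AVAILABLE
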
@@ -46,7 +46,6 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
     vit_torch_sdpa_wrapper,
-    vit_xformers_attn_wrapper,
 )
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
@@ -375,7 +374,6 @@ class Qwen2_5_VisionAttention(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -435,8 +433,6 @@ class Qwen2_5_VisionAttention(nn.Module):
                 v,
                 cu_seqlens,
             )
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)

         output, _ = self.proj(context_layer)
         return output
@@ -448,9 +444,7 @@ class Qwen2_5_VisionAttention(nn.Module):
         "cu_seqlens": 0,
         "rotary_pos_emb_cos": 0,
         "rotary_pos_emb_sin": 0,
-        "seqlens": 0,
     },
-    mark_unbacked_dims={"seqlens": 0},
     enable_if=should_torch_compile_mm_vit,
 )
 class Qwen2_5_VisionBlock(nn.Module):
@@ -501,7 +495,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
@@ -509,7 +502,6 @@ class Qwen2_5_VisionBlock(nn.Module):
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         x_fused_norm, residual = self.norm2(x, residual=x_attn)
         x = residual + self.mlp(x_fused_norm)
@@ -670,7 +662,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
@@ -822,17 +813,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
-        return max_seqlen, seqlens
+        return max_seqlen

     @staticmethod
     def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
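The simplified helper above keeps only the Flash Attention path; its cu_seqlens arithmetic is just adjacent differences followed by a max. A tiny standalone illustration with assumed values, not vLLM code:

import torch

def max_seqlen_from_cu_seqlens(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # cu_seqlens = [0, len_0, len_0 + len_1, ...]; adjacent differences give
    # per-sequence lengths, and Flash Attention only needs their maximum.
    return (cu_seqlens[1:] - cu_seqlens[:-1]).max()

print(max_seqlen_from_cu_seqlens(torch.tensor([0, 4, 6, 11])))  # tensor(5)
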
@@ -897,10 +885,8 @@ class Qwen2_5_VisionTransformer(nn.Module):

         # transformers
         # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
-        max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
-        max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
-            cu_window_seqlens
-        )
+        max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)

         cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
         cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
@@ -927,11 +913,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
                 max_seqlen_now = max_seqlen_full
-                seqlens_now = seqlens_full
             else:
                 cu_seqlens_now = cu_window_seqlens
                 max_seqlen_now = max_seqlen_window
-                seqlens_now = seqlens_window

             hidden_states = blk(
                 hidden_states,
@@ -939,7 +923,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen_now,
-                seqlens=seqlens_now,
             )

         # For Qwen2.5-VL-3B, float16 will overflow at last block
@@ -348,7 +348,6 @@ class Qwen2VisionAttention(nn.Module):
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
@@ -384,7 +383,6 @@ class Qwen2VisionAttention(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, 3 * head * head_dim]
         x, _ = self.qkv(x)
@@ -445,20 +443,6 @@ class Qwen2VisionAttention(nn.Module):
             context_layer = rearrange(
                 context_layer, "b s h d -> s b (h d)"
             ).contiguous()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            from xformers import ops as xops
-            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
-            attn_bias = BlockDiagonalMask.from_seqlens(
-                q_seqlen=seqlens, kv_seqlen=None, device=q.device
-            )
-
-            context_layer = xops.memory_efficient_attention_forward(
-                q, k, v, attn_bias=attn_bias, p=0, scale=None
-            )
-            context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()

         output, _ = self.proj(context_layer)
         return output
@@ -509,7 +493,6 @@ class Qwen2VisionBlock(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -517,7 +500,6 @@ class Qwen2VisionBlock(nn.Module):
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )

         x = x + self.mlp(self.norm2(x))
@@ -728,18 +710,14 @@ class Qwen2VisionTransformer(nn.Module):
         sin_combined = sin[pos_ids].flatten(1)
         return cos_combined, sin_combined

-    def compute_attn_mask_seqlen(
-        self, cu_seqlens: torch.Tensor
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+        max_seqlen = None
         if self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        return max_seqlen, seqlens
+        return max_seqlen

     def forward(
         self,
@@ -771,7 +749,7 @@ class Qwen2VisionTransformer(nn.Module):
         x = x.unsqueeze(1)

         # pre-compute seqlens for attn mask to reduce cuMemcpy operations
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
         cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
         for blk in self.blocks:
             x = blk(
@@ -780,7 +758,6 @@ class Qwen2VisionTransformer(nn.Module):
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )

         # adapter
@@ -224,7 +224,6 @@ class Qwen3_VisionBlock(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -232,7 +231,6 @@ class Qwen3_VisionBlock(nn.Module):
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )

         x = x + self.mlp(self.norm2(x))
@@ -500,14 +498,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
-        return max_seqlen, seqlens
+        return max_seqlen

     def forward(
         self,
@@ -533,7 +528,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         hidden_states = hidden_states.unsqueeze(1)
         rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
         rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)

         hidden_states_list = []
         deepstack_visual_indexes = self.deepstack_visual_indexes
@@ -545,7 +540,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )
             if (
                 deepstack_visual_indexes is not None
@@ -235,7 +235,6 @@ class Qwen3_VisionBlock(nn.Module):
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -243,7 +242,6 @@ class Qwen3_VisionBlock(nn.Module):
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )

         x = x + self.mlp(self.norm2(x))
@@ -391,7 +389,6 @@ class Qwen3_VisionTransformer(nn.Module):
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
@@ -531,17 +528,14 @@ class Qwen3_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if (
             self.attn_backend == AttentionBackendEnum.FLASH_ATTN
             or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
-        return max_seqlen, seqlens
+        return max_seqlen

     def forward(
         self,
@@ -569,7 +563,7 @@ class Qwen3_VisionTransformer(nn.Module):
         cu_seqlens = torch.from_numpy(cu_seqlens)

         hidden_states = hidden_states.unsqueeze(1)
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
         cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

         deepstack_feature_lists = []
@@ -580,7 +574,6 @@ class Qwen3_VisionTransformer(nn.Module):
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )
             if layer_num in self.deepstack_visual_indexes:
                 deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
@@ -277,12 +277,7 @@ class CudaPlatformBase(Platform):
         except ImportError:
             pass

-        if cls.has_device_capability(100):
-            # xFormers doesn't support Blackwell, fall back to SDPA
-            # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
-            return AttentionBackendEnum.TORCH_SDPA
-        else:
-            return AttentionBackendEnum.XFORMERS
+        return AttentionBackendEnum.TORCH_SDPA

     @classmethod
     def get_valid_backends(
@@ -49,7 +49,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
 # Possible string values of STR_BACKEND_ENV_VAR
 # register, corresponding to possible backends
 STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
-STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
 STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
 STR_INVALID_VAL: str = "INVALID"

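With the XFORMERS string removed from this registry, deployments that pinned it through the environment variable need to move to one of the remaining values. A hedged example; the exact set of valid backend strings depends on the installed vLLM build and hardware:

import os

# Pin one of the remaining backend strings before vLLM initializes attention;
# "XFORMERS" is no longer a valid value after this change.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
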
@@ -1,420 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Attention layer with XFormersAttention."""
-
-from dataclasses import dataclass
-from typing import ClassVar, Optional
-
-import torch
-
-from vllm.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionType,
-    MultipleOf,
-)
-from vllm.attention.ops.triton_unified_attention import unified_attention
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder,
-    CommonAttentionMetadata,
-    split_decodes_and_prefills,
-)
-from vllm.v1.kv_cache_interface import AttentionSpec
-
-try:
-    from xformers import ops as xops
-    from xformers.ops.fmha.attn_bias import (
-        AttentionBias,
-        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
-    )
-
-    XFORMERS_AVAILABLE = True
-except ImportError:
-    XFORMERS_AVAILABLE = False
-
-from vllm import _custom_ops as ops
-
-logger = init_logger(__name__)
-
-
-class XFormersAttentionBackend(AttentionBackend):
-    accept_output_buffer: bool = True
-    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-
-    @staticmethod
-    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
-        return [MultipleOf(16)]
-
-    @classmethod
-    def get_supported_head_sizes(cls) -> list[int]:
-        return [
-            32,
-            40,
-            48,
-            56,
-            64,
-            72,
-            80,
-            88,
-            96,
-            104,
-            112,
-            120,
-            128,
-            136,
-            144,
-            152,
-            160,
-            168,
-            176,
-            184,
-            192,
-            200,
-            208,
-            216,
-            224,
-            232,
-            240,
-            248,
-            256,
-        ]
-
-    @staticmethod
-    def get_name() -> str:
-        return "XFORMERS"
-
-    @staticmethod
-    def get_impl_cls() -> type["XFormersAttentionImpl"]:
-        return XFormersAttentionImpl
-
-    @staticmethod
-    def get_kv_cache_shape(
-        num_blocks: int,
-        block_size: int,
-        num_kv_heads: int,
-        head_size: int,
-        cache_dtype_str: str = "auto",
-    ) -> tuple[int, ...]:
-        if block_size % 16 != 0:
-            raise ValueError("Block size must be a multiple of 16.")
-        return (2, num_blocks, block_size, num_kv_heads, head_size)
-
-    @staticmethod
-    def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]:
-        return XFormersAttentionMetadataBuilder
-
-    @staticmethod
-    def use_cascade_attention(*args, **kwargs) -> bool:
-        return False
-
-
-@dataclass
-class XFormersAttentionMetadata:
-    num_actual_tokens: int  # Number of tokens excluding padding.
-    max_query_len: int
-    query_start_loc: torch.Tensor
-    max_seq_len: int
-    seq_lens: torch.Tensor
-    block_table: torch.Tensor
-    slot_mapping: torch.Tensor
-
-    num_prefill_tokens: int = 0
-    num_decode_tokens: int = 0
-    num_prefills: int = 0
-    num_decodes: int = 0
-
-    # Biases for different attention types.
-    attn_bias: Optional["AttentionBias"] = None
-
-    # Self-attention prefill/decode metadata cache
-    _cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None
-    _cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None
-
-    @property
-    def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]:
-        if self.num_prefills == 0:
-            return None
-
-        if self._cached_prefill_metadata is not None:
-            # Recover cached prefill-phase attention
-            # metadata structure
-            return self._cached_prefill_metadata
-
-        q_start_loc = self.query_start_loc[self.num_decodes :]
-        q_seqlens = torch.diff(q_start_loc)
-        kv_seqlens = self.seq_lens[self.num_decodes :]
-        # Construct & cache prefill-phase attention metadata structure
-        self._cached_prefill_metadata = XFormersAttentionMetadata(
-            num_actual_tokens=self.num_prefill_tokens,
-            max_query_len=int(q_seqlens.max().item()),
-            query_start_loc=q_start_loc - q_start_loc[0],
-            max_seq_len=int(kv_seqlens.max().item()),
-            seq_lens=kv_seqlens,
-            block_table=self.block_table[self.num_decodes :],
-            slot_mapping=self.slot_mapping[self.num_decode_tokens :],
-        )
-        return self._cached_prefill_metadata
-
-    @property
-    def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
-        if self.num_decode_tokens == 0:
-            return None
-
-        if self._cached_decode_metadata is not None:
-            # Recover cached decode-phase attention
-            # metadata structure
-            return self._cached_decode_metadata
-
-        q_start_loc = self.query_start_loc
-        q_seqlens = torch.diff(q_start_loc)
-        decode_kv_seqlens = self.seq_lens[: self.num_decodes]
-        # Construct & cache decode-phase attention metadata structure
-        self._cached_decode_metadata = XFormersAttentionMetadata(
-            num_actual_tokens=self.num_decode_tokens,
-            max_query_len=int(q_seqlens[: self.num_decodes].max().item()),
-            query_start_loc=q_start_loc[: self.num_decodes + 1],
-            max_seq_len=int(decode_kv_seqlens.max().item()),
-            seq_lens=decode_kv_seqlens,
-            block_table=self.block_table[: self.num_decodes],
-            slot_mapping=self.slot_mapping[: self.num_decode_tokens],
-            attn_bias=self.attn_bias,
-        )
-        return self._cached_decode_metadata
-
-
-class XFormersAttentionMetadataBuilder(
-    AttentionMetadataBuilder[XFormersAttentionMetadata]
-):
-    reorder_batch_threshold: int = 1
-
-    def __init__(
-        self,
-        kv_cache_spec: AttentionSpec,
-        layer_names: list[str],
-        vllm_config: VllmConfig,
-        device: torch.device,
-    ):
-        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
-
-        assert XFORMERS_AVAILABLE
-        self.block_size = kv_cache_spec.block_size
-        self._num_decodes = 0
-        self._num_decode_tokens = 0
-
-    def build(
-        self,
-        common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
-        fast_build: bool = False,
-    ) -> XFormersAttentionMetadata:
-        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
-            split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
-            )
-        )
-
-        num_actual_tokens = common_attn_metadata.num_actual_tokens
-        q_start_loc = common_attn_metadata.query_start_loc
-        q_seqlens = torch.diff(q_start_loc)
-        max_query_len = common_attn_metadata.max_query_len
-        kv_seqlens = common_attn_metadata.seq_lens
-        max_seq_len = common_attn_metadata.max_seq_len
-        block_table = common_attn_metadata.block_table_tensor
-        slot_mapping = common_attn_metadata.slot_mapping
-
-        bias = None
-        if num_decodes > 0:
-            # Construct the decoder bias.
-            decode_q_seqlens = q_seqlens[:num_decodes]
-            decode_kv_seqlens = kv_seqlens[:num_decodes]
-            bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
-                q_seqlen=decode_q_seqlens.tolist(),
-                kv_seqlen=decode_kv_seqlens.tolist(),
-                page_size=self.block_size,
-                block_tables=block_table[:num_decodes],
-                device=block_table.device,
-            )
-
-        return XFormersAttentionMetadata(
-            num_actual_tokens=num_actual_tokens,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            num_prefills=num_prefills,
-            num_decodes=num_decodes,
-            max_query_len=max_query_len,
-            query_start_loc=q_start_loc,
-            max_seq_len=max_seq_len,
-            seq_lens=kv_seqlens,
-            block_table=block_table,
-            slot_mapping=slot_mapping,
-            attn_bias=bias,
-        )
-
-
-class XFormersAttentionImpl(AttentionImpl):
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: int,
-        alibi_slopes: list[float] | None,
-        sliding_window: int | None,
-        kv_cache_dtype: str,
-        logits_soft_cap: float | None = None,
-        attn_type: AttentionType = AttentionType.DECODER,
-        kv_sharing_target_layer_name: str | None = None,
-    ) -> None:
-        if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not supported in V0.")
-        if alibi_slopes is not None:
-            raise NotImplementedError("XFormers does not support alibi slopes yet.")
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.scale = float(scale)
-        self.num_kv_heads = num_kv_heads
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        self.kv_cache_dtype = kv_cache_dtype
-        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
-        if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
-        self.alibi_slopes = alibi_slopes
-        if sliding_window is None:
-            self.sliding_window = (-1, -1)
-        else:
-            self.sliding_window = (sliding_window - 1, 0)
-        if logits_soft_cap is None:
-            # Setting logits_soft_cap to 0 means no soft cap.
-            logits_soft_cap = 0
-        self.logits_soft_cap = logits_soft_cap
-
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "XFormersAttentionImpl."
-            )
-
-    def forward(
-        self,
-        layer: torch.nn.Module,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: XFormersAttentionMetadata,
-        output: torch.Tensor | None = None,
-        output_scale: torch.Tensor | None = None,
-        output_block_scale: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        """Forward pass with XFormers.
-
-        Args:
-            query: shape = [num_tokens, num_heads, head_size]
-            key: shape = [num_tokens, num_kv_heads, head_size]
-            value: shape = [num_tokens, num_kv_heads, head_size]
-            kv_cache: shape =
-                [2, num_blocks, block_size, num_kv_heads, head_size]
-            attn_metadata: Metadata for attention.
-        Returns:
-            shape = [num_tokens, num_heads * head_size]
-        """
-        assert output is not None, "Output tensor must be provided."
-
-        if output_scale is not None or output_block_scale is not None:
-            raise NotImplementedError(
-                "fused output quantization is not yet supported"
-                " for XFormersAttentionImpl"
-            )
-
-        if attn_metadata is None:
-            # Profiling run.
-            return output.fill_(0)
-
-        # Cache the input KVs.
-        key_cache, value_cache = kv_cache.unbind(0)
-        if self.kv_sharing_target_layer_name is None:
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-            # not padded. However, we don't need to do key[:num_actual_tokens]
-            # and value[:num_actual_tokens] because the reshape_and_cache_flash
-            # op uses the slot_mapping's shape to determine the number of
-            # actual tokens.
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
-
-        num_actual_tokens = attn_metadata.num_actual_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        if prefill_meta := attn_metadata.prefill_metadata:
-            descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1])
-            unified_attention(
-                q=query[num_decode_tokens:num_actual_tokens],
-                k=key_cache,
-                v=value_cache,
-                out=output[num_decode_tokens:num_actual_tokens],
-                cu_seqlens_q=prefill_meta.query_start_loc,
-                max_seqlen_q=prefill_meta.max_query_len,
-                seqused_k=prefill_meta.seq_lens,
-                max_seqlen_k=prefill_meta.max_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=prefill_meta.block_table,
-                softcap=self.logits_soft_cap,
-                q_descale=None,  # Not supported
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
-            )
-
-        if decode_meta := attn_metadata.decode_metadata:
-            # Query for decode. KV is not needed because it is already cached.
-            decode_query = query[:num_decode_tokens]
-            # Reshape query to [1, B_T, G, H, D].
-            q = decode_query.view(
-                1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size
-            )
-            # Reshape the k and v caches to [1, Bkv_T, G, H, D]
-            cache_k = key_cache.view(
-                1, -1, self.num_kv_heads, 1, self.head_size
-            ).expand(
-                1,
-                -1,
-                self.num_kv_heads,
-                self.num_queries_per_kv,
-                self.head_size,
-            )
-            cache_v = value_cache.view(
-                1, -1, self.num_kv_heads, 1, self.head_size
-            ).expand(
-                1,
-                -1,
-                self.num_kv_heads,
-                self.num_queries_per_kv,
-                self.head_size,
-            )
-
-            attn_bias = decode_meta.attn_bias
-            output[:num_decode_tokens] = xops.memory_efficient_attention_forward(
-                q,
-                cache_k,
-                cache_v,
-                attn_bias=attn_bias,
-                p=0.0,
-                scale=self.scale,
-            ).view(decode_query.shape)
-
-        # Reshape the output tensor.
-        return output
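The removed decode path above expands the KV cache across query groups so a single attention call can serve grouped-query attention. A small standalone sketch of that reshape with dummy tensors; the shapes are illustrative, not the backend's real buffers:

import torch

num_tokens, num_kv_heads, num_queries_per_kv, head_size = 3, 2, 4, 8
# Packed decode queries: one query token per sequence.
query = torch.randn(num_tokens, num_kv_heads * num_queries_per_kv, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)

# [1, B_T, G, H, D]: group dimension G = num_kv_heads, H = queries per KV head.
q = query.view(1, -1, num_kv_heads, num_queries_per_kv, head_size)
# Broadcast each KV head across its query group without copying memory.
k = key.view(1, -1, num_kv_heads, 1, head_size).expand(
    1, -1, num_kv_heads, num_queries_per_kv, head_size
)
print(q.shape, k.shape)  # torch.Size([1, 3, 2, 4, 8]) torch.Size([1, 3, 2, 4, 8])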