diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 156456c92e63c..5e5a532cb57d5 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -108,7 +108,6 @@ fi if [[ $commands == *" kernels/attention"* ]]; then commands="${commands} \ --ignore=kernels/attention/test_attention_selector.py \ - --ignore=kernels/attention/test_blocksparse_attention.py \ --ignore=kernels/attention/test_encoder_decoder_attn.py \ --ignore=kernels/attention/test_flash_attn.py \ --ignore=kernels/attention/test_flashinfer.py \ diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 887f754a3d1c8..f5a89ab6cf7dd 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -376,7 +376,6 @@ Specified using `--task generate`. | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ | | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | | | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | diff --git a/tests/kernels/attention/test_blocksparse_attention.py b/tests/kernels/attention/test_blocksparse_attention.py deleted file mode 100644 index 9aee818c99569..0000000000000 --- a/tests/kernels/attention/test_blocksparse_attention.py +++ /dev/null @@ -1,441 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random -from typing import Optional - -import pytest -import torch - -from tests.kernels.allclose_default import get_default_atol, get_default_rtol -from vllm import _custom_ops as ops -from vllm.attention.ops.blocksparse_attention.interface import ( - LocalStridedBlockSparseAttn) -from vllm.platforms import current_platform -from vllm.utils import get_max_shared_memory_bytes - -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 -# This will change depending on the compute capability. -# - 512 as a buffer -MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -# MAX_SEQ_LEN = 2771 - -# There may not be enough gpu memory due to large NUM_BLOCKS. -# Reduce NUM_BLOCKS when it happens. 
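For context on the constant above: `MAX_SEQ_LEN` is bounded by the device's shared memory, so its value varies from GPU to GPU. A rough illustration with an assumed 96 KiB per-block limit (the real figure comes from `get_max_shared_memory_bytes()`):

```python
# Assumed figure purely for illustration; get_max_shared_memory_bytes()
# reports the actual per-block limit of the device under test.
shared_mem_bytes = 96 * 1024          # e.g. a GPU exposing 96 KiB per block
FLOAT32_BYTES = 4
MAX_SEQ_LEN = shared_mem_bytes // FLOAT32_BYTES - 512   # 24576 - 512 = 24064
```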
-NUM_BLOCKS = 4321 # Arbitrary values for testing -PARTITION_SIZE = 512 -DTYPES = [torch.half, torch.bfloat16] -NUM_GEN_SEQS = [3] # Arbitrary values for testing -NUM_PREFILL_SEQS = [3] # Arbitrary values for testing -NUM_HEADS = [(40, 40)] # Arbitrary values for testing - -HEAD_SIZES = [64, 112] -BLOCK_SIZES = [16] -USE_ALIBI = [False, True] -KV_CACHE_DTYPE = ["auto", "fp8"] -SEEDS = [0] -CUDA_DEVICES = ['cuda:0'] -BLOCKSPARSE_LOCAL_BLOCKS = [16] -BLOCKSPARSE_VERT_STRIDES = [8] - -BLOCKSPARSE_BLOCK_SIZES = [64] -BLOCKSPARSE_HEADS_SLIDINGS = [2, -1] -BLOCKSPARSE_HOMO_HEADS = [True, False] - - -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - if attn_mask is not None: - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out - - -def ref_single_query_cached_kv_attention( - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - scale: float, - alibi_slopes: Optional[torch.Tensor], - tp_rank: int = 0, - blocksparse_local_blocks: int = 0, - blocksparse_vert_stride: int = 1, - blocksparse_block_size: int = 64, - blocksparse_head_sliding_step: int = 0, -) -> None: - num_query_heads = query.shape[1] - num_kv_heads = value_cache.shape[1] - head_size = value_cache.shape[2] - block_size = value_cache.shape[3] - num_seqs = query.shape[0] - - block_tables_lst = block_tables.cpu().tolist() - seq_lens_lst = seq_lens.cpu().tolist() - for i in range(num_seqs): - q = query[i].unsqueeze(0) - block_table = block_tables_lst[i] - seq_len = int(seq_lens_lst[i]) - - keys_lst: list[torch.Tensor] = [] - values_lst: list[torch.Tensor] = [] - for j in range(seq_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, :, :, block_offset, :] - k = k.reshape(num_kv_heads, head_size) - keys_lst.append(k) - - v = value_cache[block_number, :, :, block_offset] - values_lst.append(v) - keys = torch.stack(keys_lst, dim=0) - values = torch.stack(values_lst, dim=0) - if num_queries_per_kv > 1: - # Handle MQA and GQA - keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) - values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - - alibi_bias = None - if alibi_slopes is not None: - # Create the ALiBi bias used in the paged attention kernel. 
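As a quick, illustrative aside on the bias the next few lines construct: for a 4-token sequence the per-key offsets are `[-3, -2, -1, 0]`, scaled per head by its slope. The slope values below are made up.

```python
import torch

# Illustrative only: ALiBi bias for the last query position of a 4-token
# sequence, matching the (position_ids - seq_len + 1) construction below.
seq_len = 4
alibi_slopes = torch.tensor([0.5, 0.25])                 # one slope per head
offsets = (torch.arange(seq_len) - seq_len + 1).float()  # tensor([-3., -2., -1., 0.])
alibi_bias = alibi_slopes.view(-1, 1, 1) * offsets.view(1, 1, -1)
# alibi_bias[0, 0] -> tensor([-1.5000, -1.0000, -0.5000, 0.0000])
```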
- position_ids = torch.arange(seq_len).int() - alibi_bias = (position_ids - seq_len + 1).float() - alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( - 1, 1, -1) - - if blocksparse_vert_stride >= 1: - bsize = blocksparse_block_size - hsliding = blocksparse_head_sliding_step - vert = blocksparse_vert_stride - locals = blocksparse_local_blocks - qb = (seq_len - 1) // bsize - attn_mask = q.new_zeros( - (num_query_heads, 1, seq_len)).float() - torch.inf - for h in range(num_query_heads): - if hsliding >= 0: # slide with q heads - bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1 - else: # slide with kv heads - bs_offset = (tp_rank * num_kv_heads + - h // num_queries_per_kv) * (-hsliding) + 1 - for kb in range(qb + 1): - kj = kb * bsize - if (qb - kb) < locals or \ - (kb + bs_offset) % vert == 0: - attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0 - if alibi_bias is not None: - attn_mask += alibi_bias - else: - attn_mask = alibi_bias - - out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask) - out = out.view(num_query_heads, head_size) - output[i].copy_(out, non_blocking=True) - - -@pytest.mark.parametrize("version", ["v1", "v2"]) -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("use_alibi", USE_ALIBI) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) -@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) -@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) -@pytest.mark.parametrize("blocksparse_head_sliding_step", - BLOCKSPARSE_HEADS_SLIDINGS) -def test_paged_attention( - kv_cache_factory, - version: str, - num_seqs: int, - num_heads: tuple[int, int], - head_size: int, - use_alibi: bool, - block_size: int, - dtype: torch.dtype, - kv_cache_dtype: str, - seed: int, - device: str, - blocksparse_local_blocks: int, - blocksparse_vert_stride: int, - blocksparse_block_size: int, - blocksparse_head_sliding_step: int, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) - query.uniform_(-scale, scale) - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None - if use_alibi: - alibi_slopes = torch.rand(num_query_heads, dtype=torch.float) - - seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - seq_lens[-1] = MAX_SEQ_LEN - max_seq_len = max(seq_lens) - seq_lens = torch.tensor(seq_lens, dtype=torch.int) - - # Create the block tables. - max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int) - - # Create the KV caches. 
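Stepping back to the reference mask assembled above: a key block survives when it is within `blocksparse_local_blocks` of the query block or when it lands on the vertical stride. A standalone restatement for a single head with no tensor-parallel offset (illustrative only, not part of the deleted test):

```python
def kept_key_blocks(q_block: int, local_blocks: int, vert_stride: int,
                    bs_offset: int = 1) -> list[int]:
    """Causal block-sparse pattern mirrored from the reference mask above."""
    return [kb for kb in range(q_block + 1)           # causal: kb <= q_block
            if (q_block - kb) < local_blocks           # local window
            or (kb + bs_offset) % vert_stride == 0]    # vertical stride

# With 2 local blocks and a stride of 4, query block 7 attends to
# key blocks 3 (stride), 6 and 7 (local):
print(kept_key_blocks(7, local_blocks=2, vert_stride=4))  # [3, 6, 7]
```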
- key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) - key_cache, value_cache = key_caches[0], value_caches[0] - - # Using default kv_scale - k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) - tp_rank = 0 - - # Call the paged attention kernel. - output = torch.empty_like(query) - if version == "v1": - ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - tp_rank=tp_rank, - blocksparse_local_blocks=blocksparse_local_blocks, - blocksparse_vert_stride=blocksparse_vert_stride, - blocksparse_block_size=blocksparse_block_size, - blocksparse_head_sliding_step=blocksparse_head_sliding_step, - ) - elif version == "v2": - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) - assert PARTITION_SIZE % block_size == 0 - num_seqs, num_heads, head_size = output.shape - tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), - dtype=output.dtype, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), - dtype=torch.float32, - ) - max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - tp_rank=tp_rank, - blocksparse_local_blocks=blocksparse_local_blocks, - blocksparse_vert_stride=blocksparse_vert_stride, - blocksparse_block_size=blocksparse_block_size, - blocksparse_head_sliding_step=blocksparse_head_sliding_step, - ) - else: - raise AssertionError(f"Unknown version: {version}") - - # Run the reference implementation. - if kv_cache_dtype == "fp8": - # Convert cache data back to dtype. - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(dequantized_key_cache, key_cache) - key_cache = dequantized_key_cache - - value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(dequantized_value_cache, value_cache) - value_cache = dequantized_value_cache - - ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - seq_lens, - scale, - alibi_slopes, - tp_rank, - blocksparse_local_blocks, - blocksparse_vert_stride, - blocksparse_block_size, - blocksparse_head_sliding_step, - ) - - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. - atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 - rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 - - # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, - # so we use a relaxed tolerance for the test. 
- atol, rtol = 1e-3, 1e-5 - if kv_cache_dtype == "fp8": - atol, rtol = 1e-2, 1e-5 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) - - -def ref_multi_query_kv_attention( - cu_seq_lens: list[int], - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - dtype: torch.dtype, -) -> torch.Tensor: - num_seqs = len(cu_seq_lens) - 1 - ref_outputs = [] - for i in range(num_seqs): - start_idx = cu_seq_lens[i] - end_idx = cu_seq_lens[i + 1] - seq_len = end_idx - start_idx - - # Create attention mask. - attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), - diagonal=1) - attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype) - - ref_output = ref_masked_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - scale, - attn_mask=attn_mask, - ) - ref_outputs.append(ref_output) - ref_output = torch.cat(ref_outputs, dim=0) - return ref_output - - -@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) -@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) -@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) -@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_varlen_blocksparse_attention_prefill( - num_seqs: int, - num_heads: tuple[int, int], - head_size: int, - blocksparse_local_blocks: int, - blocksparse_vert_stride: int, - blocksparse_block_size: int, - blocksparse_homo_heads: bool, - dtype: torch.dtype, - seed: int, - device: str, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. - # As the xformers library is already tested with its own tests, we can use - # a smaller MAX_SEQ_LEN here. 
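The prefill test that follows packs every sequence into one tensor and addresses it through cumulative lengths, the same layout the reference helper above indexes with `cu_seq_lens[i]:cu_seq_lens[i + 1]`. A minimal illustration with arbitrary shapes:

```python
import torch

seq_lens = [4, 6]
cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)  # tensor([0, 4, 10])
packed_q = torch.randn(sum(seq_lens), 8, 64)   # (total_tokens, heads, head_size)
# Tokens of sequence i live at packed_q[cu_seq_lens[i]:cu_seq_lens[i + 1]]:
q_seq1 = packed_q[cu_seq_lens[1]:cu_seq_lens[2]]   # shape (6, 8, 64)
```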
- max_len = min(MAX_SEQ_LEN, 4096) - seq_lens = random.sample(range(1, max_len), num_seqs) - cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0) - num_tokens = sum(seq_lens) - - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - - qkv = torch.empty(num_tokens, - num_query_heads + 2 * num_kv_heads, - head_size, - dtype=dtype) - qkv.uniform_(-scale, scale) - query, key, value = qkv.split( - [num_query_heads, num_kv_heads, num_kv_heads], dim=1) - - bs_attn_op = LocalStridedBlockSparseAttn( - num_query_heads, - max_len, - local_blocks=blocksparse_local_blocks, - vert_stride=blocksparse_vert_stride, - block_size=blocksparse_block_size, - device=device, - dtype=dtype, - homo_head=blocksparse_homo_heads) - - output = bs_attn_op(query, - key, - value, - cu_seq_lens.to(device), - sm_scale=scale) - - if num_queries_per_kv > 1: - # Handle MQA and GQA - key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) - value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) - - ref_output = ref_multi_query_kv_attention( - cu_seq_lens.tolist(), - query, - key, - value, - scale, - dtype, - ) - torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index 34311b9ccd767..d56d3f4638f1c 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -33,8 +33,12 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): # change the attention backend to triton MLA m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") - backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, - False, True) + backend = get_attn_backend(576, + torch.bfloat16, + "auto", + 16, + False, + use_mla=True) assert (backend.get_name() == "TRITON_MLA" or backend.get_name() == "TRITON_MLA_VLLM_V1") @@ -42,15 +46,23 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): # If use_mla is true # The selected backend is triton MLA m.setenv(STR_BACKEND_ENV_VAR, None) - backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, - False, True) + backend = get_attn_backend(576, + torch.bfloat16, + "auto", + 16, + False, + use_mla=True) assert (backend.get_name() == "TRITON_MLA" or backend.get_name() == "TRITON_MLA_VLLM_V1") # change the attention backend to AITER MLA m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") - backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, - False, True) + backend = get_attn_backend(576, + torch.bfloat16, + "auto", + 1, + False, + use_mla=True) assert (backend.get_name() == "ROCM_AITER_MLA" or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") @@ -60,7 +72,11 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): # The selected backend is ROCM_AITER_MLA m.setenv(STR_BACKEND_ENV_VAR, None) m.setenv("VLLM_ROCM_USE_AITER", "1") - backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, - False, True) + backend = get_attn_backend(576, + torch.bfloat16, + "auto", + 1, + False, + use_mla=True) assert (backend.get_name() == "ROCM_AITER_MLA" or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") diff --git a/tests/models/registry.py b/tests/models/registry.py index 5c546a6c86da2..8afac32e1cf04 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -247,10 +247,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "PersimmonForCausalLM": 
_HfExamplesInfo("adept/persimmon-8b-chat"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), - # Blocksparse attention not supported in V1 yet - "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct", - trust_remote_code=True, - v0_only=True), "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501 trust_remote_code=True, v0_only=True, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 05c098a58a0d2..ba20da4fd75f6 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -269,7 +269,6 @@ class AttentionImpl(ABC, Generic[T]): alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py deleted file mode 100644 index e4338805f5649..0000000000000 --- a/vllm/attention/backends/blocksparse_attn.py +++ /dev/null @@ -1,466 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import (CommonAttentionState, - CommonMetadataBuilder) -from vllm.attention.ops.blocksparse_attention.interface import ( - LocalStridedBlockSparseAttn, get_head_sliding_step) -from vllm.attention.ops.paged_attn import PagedAttention -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) - - -@dataclass -class BlocksparseParams: - max_seqlen: int - - # Num q heads per tensor-parallel rank/partition - num_heads: int # per TP partition - # Num kv heads per tensor-parallel rank/partition - num_kv_heads: int - - # block size used for blocksparse attention. - # This is the block_size used in `local_blocks`, `vert_stride`. - block_size: int - - # Number of blocks for local attention, i.e., number of - # local attended tokens / `sparse_block_size` - local_blocks: int - - # Attend to one block per every `vert_stride` blocks. - # Controlling the sparsity - vert_stride: int - """ - If to use the same vertical stride offset for all heads, - i.e., attend to the same block of tokens on all heads. - By default, it is False, i.e., attention on the non-local - blocks depends on the `head_idx`, that is on - blocks satisfying - `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0` - where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`, - `block_idx = position_id // sparse_block_size`. - See `..ops.blocksparse_attention.utils:get_sparse_attn_mask` - for more detail. - """ - homo_head: bool = False - - # If within a group, the kv offsets that each q attends is the same or no. 
- homo_head_group: bool = False - - # Decided by homo_head and homo_head group - head_sliding_step: int = field(init=False) - - # range of q heads to for a TP rank - active_head_range: Tuple = field(init=False) - - def __post_init__(self): - assert self.block_size > 0 - assert self.local_blocks >= 0 - assert self.vert_stride >= 1 - - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - total_heads = tp_size * self.num_heads - total_kv_heads = tp_size * self.num_kv_heads - - if self.homo_head: - self.head_sliding_step = 0 - elif self.homo_head_group: - head_sliding_step = get_head_sliding_step(total_kv_heads, - self.vert_stride) - # negative indicates sliding along kv heads, i.e., homo q group - self.head_sliding_step = -head_sliding_step - else: - self.head_sliding_step = get_head_sliding_step( - total_heads, self.vert_stride) - - self.active_head_range = ( - tp_rank * self.num_heads, - (tp_rank + 1) * self.num_heads, - ) - - -class BlocksparseFlashAttentionBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "BLOCK_SPARSE_FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: - return BlocksparseFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return BlocksparseFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: - return BlocksparseFlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], - ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], - ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) - - -@dataclass -class BlocksparseFlashAttentionMetadata(AttentionMetadata): - """A copy of Metadata for FlashAttentionBackend, - to avoid having to install flash_attn. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. 
- max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # Max number of query tokens for among request in the batch. - max_decode_query_len: Optional[int] = None - - _cached_prefill_metadata: Optional[ - "BlocksparseFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional[ - "BlocksparseFlashAttentionMetadata"] = None - - @property - def prefill_metadata( - self) -> Optional["BlocksparseFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - assert self.seq_start_loc is not None - - self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - multi_modal_placeholder_index_maps=self. 
- multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cached_decode_metadata = BlocksparseFlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - ) - return self._cached_decode_metadata - - -class BlocksparseFlashAttentionMetadataBuilder( - CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]): - - _metadata_cls = BlocksparseFlashAttentionMetadata - - -class BlocksparseFlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. 
- - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "BLOCK_SPARSE_FLASH_ATTN Backend.") - assert blocksparse_params is not None - assert alibi_slopes is None, ValueError( - "Alibi not support for blocksparse flash attention.") - assert sliding_window is None, ValueError( - "sliding_window is invalid for blocksparse attention.") - assert logits_soft_cap is None, ValueError( - "logits_soft_cap is invalid for blocksparse attention.") - - if "num_heads" not in blocksparse_params: - blocksparse_params["num_heads"] = num_heads - if "num_kv_heads" not in blocksparse_params: - blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads - self.blocksparse_params = BlocksparseParams(**blocksparse_params) - self.kv_cache_dtype = kv_cache_dtype - - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.alibi_slopes = alibi_slopes - self.num_kv_heads = num_kv_heads - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - self.local_blocks = self.blocksparse_params.local_blocks - self.vert_stride = self.blocksparse_params.vert_stride - self.sparse_block_size = self.blocksparse_params.block_size - self.head_sliding_step = self.blocksparse_params.head_sliding_step - - supported_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - total_num_heads = num_heads * self.tp_size - self.bs_attn = LocalStridedBlockSparseAttn( - total_num_heads, - self.blocksparse_params.max_seqlen, - self.blocksparse_params.local_blocks, - self.blocksparse_params.vert_stride, - self.blocksparse_params.block_size, - homo_head=self.blocksparse_params.homo_head, - active_head_range=self.blocksparse_params.active_head_range, - ) - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "BlocksparseFlashAttentionImpl") - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: BlocksparseFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. 
- Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for BlocksparseFlashAttentionImpl") - - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache.numel() > 0: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if prefill_meta := attn_metadata.prefill_metadata: - - # Prompt run. - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - - assert kv_cache.numel() == 0 \ - or prefill_meta.block_tables is None \ - or prefill_meta.block_tables.numel() == 0, \ - "Does not support prefix-enabled attention." - - output = self.bs_attn( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - sm_scale=self.scale, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - output = PagedAttention.forward_decode( - query, - key_cache, - value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - self.blocksparse_params.max_seqlen, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - layer._k_scale, - layer._v_scale, - tp_rank=self.tp_rank, - blocksparse_local_blocks=self.local_blocks, - blocksparse_vert_stride=self.vert_stride, - blocksparse_block_size=self.sparse_block_size, - blocksparse_head_sliding_step=self.head_sliding_step, - ) - - assert output is not None - # Reshape the output tensor. 
- return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index 1c139952371a9..bd9bc427728d0 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -667,7 +667,6 @@ class DifferentialFlashAttentionImpl(AttentionImpl): alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -680,9 +679,6 @@ class DifferentialFlashAttentionImpl(AttentionImpl): differential_flash_attention_config self.used_shared_kv_cache = kv_sharing_target_layer_name is not None self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - if blocksparse_params is not None: - raise ValueError( - "FlashAttention does not support block-sparse attention.") if use_irope: logger.warning( "Using irope in V0 is not supported yet, it will fall back " diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index 40557a4e8f8f5..e108646e7ffb5 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -287,7 +287,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 20e67eb9b401e..ee36fd19e0122 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch @@ -615,7 +615,6 @@ class FlashAttentionImpl(AttentionImpl): alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -624,9 +623,6 @@ class FlashAttentionImpl(AttentionImpl): if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0 " "FLASH_ATTN backend.") - if blocksparse_params is not None: - raise ValueError( - "FlashAttention does not support block-sparse attention.") if use_irope: logger.warning( "Using irope in V0 is not supported yet, it will fall back " diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 1f913ad895238..56d3da699f405 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -999,7 +999,6 @@ class FlashInferImpl(AttentionImpl): alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, diff --git 
a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index e185d0260d0a0..a242ac9bbe0b6 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type import torch @@ -181,7 +181,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str] = None, @@ -189,20 +188,17 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) assert is_flashmla_supported(), \ "FlashMLA is not supported on this device" - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 0c3ff26d04c8b..52c4a9e7da3de 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -997,7 +997,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 1edf34351db3f..a165a786d63d0 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Type, Union +from typing import TYPE_CHECKING, Optional, Type, Union import torch @@ -367,7 +367,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -375,17 +374,14 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "Aiter MLA does not support one of the following: " - 
"alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") from aiter import flash_attn_varlen_func self.flash_attn_varlen_func = flash_attn_varlen_func diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 4653d5267e197..1ee1dea729d9e 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -4,7 +4,7 @@ import itertools from dataclasses import dataclass from functools import cache -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type import torch @@ -494,7 +494,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -507,9 +506,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): logger.warning_once( "Using irope in ROCm Flash Attention is not supported yet, it " "will fail back to global attention for long context.") - if blocksparse_params is not None: - raise ValueError( - "ROCmFlashAttention does not support blocksparse attention.") if use_irope: logger.warning( "Using irope in V0 is not supported yet, it will fall back " diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index e06f7d54e3421..fba5b5f6bca86 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Dict, List, Optional, Type +from typing import List, Optional, Type import torch @@ -35,7 +35,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -43,17 +42,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3ef79bb621208..0bc38b4142901 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with xFormers and PagedAttention.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type 
import torch from xformers import ops as xops @@ -387,7 +387,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -396,9 +395,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0 " "XFORMERS backend.") - if blocksparse_params is not None: - raise ValueError( - "XFormers does not support block-sparse attention.") if logits_soft_cap is not None: logger.warning_once("XFormers does not support logits soft cap. " "Outputs may be slightly off.") diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index d0677525d3106..5d8ffb8e82d3f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch import torch.nn as nn @@ -74,7 +74,6 @@ class Attention(nn.Module): alibi_slopes: Optional[List[float]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, use_mla: bool = False, @@ -163,12 +162,11 @@ class Attention(nn.Module): kv_cache_dtype, block_size, is_attention_free, - blocksparse_params is not None, use_mla=use_mla) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **extra_impl_args) self.backend = backend_name_to_enum(attn_backend.get_name()) self.dtype = dtype diff --git a/vllm/attention/ops/blocksparse_attention/__init__.py b/vllm/attention/ops/blocksparse_attention/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py deleted file mode 100644 index 05fa9d11f2283..0000000000000 --- a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py +++ /dev/null @@ -1,433 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.triton_utils import tl, triton - - -def blocksparse_flash_attn_varlen_fwd( - q, - k, - v, # (#tokens, n_heads, head_size) - cu_seqlens_k, - cu_seqlens_q, - sm_scale, - sparse_layout, - *, - block_size=64, - q_block_size=None, - max_seqlen=None): - # split q to blocks - - assert isinstance(sparse_layout, (list, tuple)) - - _, n_heads, head_size = q.shape - batch_size = cu_seqlens_k.size(0) - 1 - q_block_size = q_block_size or block_size - - assert q.dim() == k.dim() == v.dim() == 3 - assert q.size(1) % k.size(1) == 0 - assert q.size(2) == k.size(2) - # TODO(linxihui): allow k, v to have different head_size - assert k.shape == v.shape - assert cu_seqlens_k.dim() == 1 - - q_k_ratio = q.size(1) // k.size(1) - - if cu_seqlens_q is None: - if q.size(0) == batch_size: # 
decoding only - cu_seqlens_q = torch.arange( - 0, - batch_size + 1, - dtype=cu_seqlens_k.dtype, - device=cu_seqlens_k.device, - ) - elif q.size(0) == k.size(0): - cu_seqlens_q = cu_seqlens_k - else: - raise ValueError("cu_seqlens_q must be specified\ - if it mix of prefilling and decoding.") - else: - assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) - - # switch to use cpu to avoid too many kernel launches when iterated over - q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() - k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() - - assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), ( - "length of q should either be 1 (decoding) or same as k (prefilling).") - - if max_seqlen: - assert k_lens.max() <= max_seqlen - - n_blocks = (q_lens + q_block_size - 1) // q_block_size - - q_batch_ids = torch.tensor( - [i for i, n in enumerate(n_blocks) for _ in range(n)], - dtype=cu_seqlens_q.dtype, - device=cu_seqlens_q.device, - ) - q_start_sids = torch.tensor( - [i * q_block_size for n in n_blocks for i in range(n)], - dtype=cu_seqlens_q.dtype, - device=cu_seqlens_q.device, - ) - - out = q.new_empty(q.shape) - cu_seqlens_q = cu_seqlens_q.contiguous() - cu_seqlens_k = cu_seqlens_k.contiguous() - - layout_crow_indices, layout_col_indices = sparse_layout - block_d = triton.next_power_of_2(head_size) - - decoding_only = (q_lens == 1).all().item() - grid = (len(q_start_sids), n_heads, 1) - - _fwd_kernel_batch_inference[grid]( - q, - k, - v, - out, - sm_scale, - cu_seqlens_q[:-1], - cu_seqlens_q[1:], - cu_seqlens_k[:-1], - cu_seqlens_k[1:], - q_batch_ids, - q_start_sids, - 0, - *q.stride(), - 0, - *k.stride(), - 0, - *v.stride(), - 0, - *out.stride(), - layout_crow_indices, - layout_col_indices, - *layout_crow_indices.stride(), - *layout_col_indices.stride(), - q_k_ratio, - HAS_BATCH_DIM=False, - D_HEAD=head_size, - BLOCK_M=q_block_size, - BLOCK_N=block_size, - BLOCK_D=block_d, - BLOCK_M_LOADING=(16 if decoding_only else - q_block_size), # smaller for decoding - EVEN_D=block_d == head_size, - num_warps=1 if decoding_only else 4, - num_stages=3) - - return out - - -@triton.jit -def _fwd_kernel_inner( - acc, - l_i, - m_i, - q, - Q, - k_block_col_idx, - layout_col_ptr, - layout_col_stride_h, - layout_col_stride_m, - k_ptrs, - v_ptrs, - off_h, - offs_m, - offs_n, - offs_d, - stride_kt, - stride_vt, - sm_scale, - k_seqlen, - past_len, - LAST_K_BLOCK: tl.constexpr, - BLOCK_M_LOADING: tl.constexpr, - BLOCK_N: tl.constexpr, - D_HEAD: tl.constexpr, - EVEN_D: tl.constexpr, - M_LT_N: tl.constexpr, -): - k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + - k_block_col_idx * layout_col_stride_m).to(tl.int32) - start_n = k_block_id * BLOCK_N - if LAST_K_BLOCK: - if EVEN_D: - k = tl.load( - k_ptrs + start_n * stride_kt, - mask=offs_n[None, :] + start_n < k_seqlen, - other=0.0, - ) - else: - k = tl.load( - k_ptrs + start_n * stride_kt, - mask=(offs_n[None, :] + start_n < k_seqlen) & - (offs_d[:, None] < D_HEAD), - other=0.0, - ) - else: - if EVEN_D: - k = tl.load(k_ptrs + start_n * stride_kt) - else: - k = tl.load(k_ptrs + start_n * stride_kt, - mask=offs_d[:, None] < D_HEAD, - other=0.0) - - qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N - if LAST_K_BLOCK | M_LT_N: - qk += tl.where( - offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), - 0, - float("-inf"), - ) - - # flash-attn2 - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.math.exp2(qk - m_ij[:, None]) 
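The update around this point, together with the `1.44269504` factor folded into `sm_scale` further down, is the flash-attention-2 online softmax carried out in base 2 (`exp(x) == exp2(x * log2(e))`). A plain-PyTorch restatement of one key-block step, kept in base e for readability; this is a sketch, not the kernel's exact code:

```python
import torch

def online_softmax_step(qk, m_i, l_i, acc, v):
    """One key-block update: qk (M, N) scaled scores, v (N, D) values,
    m_i (M,) running max, l_i (M,) running denominator, acc (M, D) output."""
    m_new = torch.maximum(m_i, qk.max(dim=1).values)
    p = torch.exp(qk - m_new[:, None])     # kernel: exp2 on pre-scaled scores
    alpha = torch.exp(m_i - m_new)         # rescale previous partial results
    acc = acc * alpha[:, None] + p @ v
    l_i = l_i * alpha + p.sum(dim=1)
    return m_new, l_i, acc

# After the last key block the kernel normalizes: out = acc / l_i[:, None].
```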
- l_ij = tl.sum(p, 1) - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - # update m_i - m_i = m_ij - l_i = l_i * alpha + l_ij - - p = p.to(Q.dtype.element_ty) - # update acc - if LAST_K_BLOCK: - if EVEN_D: - v = tl.load( - v_ptrs + start_n * stride_vt, - mask=offs_n[:, None] + start_n < k_seqlen, - other=0.0, - ) - else: - v = tl.load( - v_ptrs + start_n * stride_vt, - mask=(offs_n[:, None] + start_n < k_seqlen) & - (offs_d[None, :] < D_HEAD), - other=0.0, - ) - else: - if EVEN_D: - v = tl.load(v_ptrs + start_n * stride_vt) - else: - v = tl.load(v_ptrs + start_n * stride_vt, - mask=offs_d[None, :] < D_HEAD, - other=0.0) - - acc += tl.dot(p, v) - - return acc, l_i, m_i - - -@triton.heuristics({ - "M_LT_N": - lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"], -}) -@triton.jit -def _fwd_kernel_batch_inference( - Q, - K, - V, - Out, - sm_scale, - q_batch_starts, - q_batch_ends, - k_batch_starts, - k_batch_ends, - q_batch_ids, - q_start_sids, - stride_qb, - stride_qt, - stride_qh, - stride_qd, - stride_kb, - stride_kt, - stride_kh, - stride_kd, - stride_vb, - stride_vt, - stride_vh, - stride_vd, - stride_ob, - stride_ot, - stride_oh, - stride_od, - layout_crow_ptr, - layout_col_ptr, - layout_crow_stride_h, - layout_crow_stride_m, - layout_col_stride_h, - layout_col_stride_m, - q_k_ratio, - HAS_BATCH_DIM: tl.constexpr, - D_HEAD: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D: tl.constexpr, - BLOCK_M_LOADING: tl.constexpr, - EVEN_D: tl.constexpr, - M_LT_N: tl.constexpr, -): - """ - NOTATION: - pid: position id - sid: storage id - sbid: storage block id - pbid: position block id - offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) - - TODO(linxihui): - Optimize grouped-attn - """ - off_zm = tl.program_id(0) - off_h = tl.program_id(1) - - off_h_for_kv = off_h // q_k_ratio - - if HAS_BATCH_DIM: - off_z = tl.program_id(2) - Q += off_z * stride_qb - K += off_z * stride_kb - V += off_z * stride_vb - Out += off_z * stride_ob - start_m = off_zm - q_start_sid = start_m * BLOCK_M # always 0 for decoding - else: - off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] - q_start_sid = tl.load(q_start_sids + off_zm) - start_m = q_start_sid // BLOCK_M # q_sbid - - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_D) - - q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) - q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start - k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) - k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start - past_len = k_seqlen - q_seqlen - - Q += q_cu_start * stride_qt + off_h * stride_qh - K += k_cu_start * stride_kt + off_h_for_kv * stride_kh - V += k_cu_start * stride_vt + off_h_for_kv * stride_vh - Out += q_cu_start * stride_ot + off_h * stride_oh - - q_pbid = (past_len + q_start_sid) // BLOCK_M - - if EVEN_D: - q = tl.load( - Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, - mask=offs_m[:, None] < q_seqlen, - other=0.0, - ) - else: - q = tl.load( - Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, - mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), - other=0.0, - ) - - sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + - q_pbid * layout_crow_stride_m) - - # TODO(linxihui): load at once, with any Triton version - # that supports `tl.split`, e.g., Triton 3.0 - k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) - k_block_end = 
tl.load(sparse_crow_ptr + 1).to(tl.int32) - - m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) - - k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd - v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd - - sm_scale *= ( - 1.44269504 # 1/log2 as we use base2 for exponential and logarithm - ) - - for k_block_col_idx in range(k_block_start, k_block_end - 1): - acc, l_i, m_i = _fwd_kernel_inner( - acc, - l_i, - m_i, - q, - Q, - k_block_col_idx, - layout_col_ptr, - layout_col_stride_h, - layout_col_stride_m, - k_ptrs, - v_ptrs, - off_h, - offs_m, - offs_n, - offs_d, - stride_kt, - stride_vt, - sm_scale, - k_seqlen, - past_len, - False, - BLOCK_M_LOADING, - BLOCK_N, - D_HEAD, - EVEN_D, - M_LT_N, - ) - - acc, l_i, m_i = _fwd_kernel_inner( - acc, - l_i, - m_i, - q, - Q, - k_block_end - 1, - layout_col_ptr, - layout_col_stride_h, - layout_col_stride_m, - k_ptrs, - v_ptrs, - off_h, - offs_m, - offs_n, - offs_d, - stride_kt, - stride_vt, - sm_scale, - k_seqlen, - past_len, - True, - BLOCK_M_LOADING, - BLOCK_N, - D_HEAD, - EVEN_D, - M_LT_N, - ) - - # flash-attn 2 - m_i += tl.math.log2(l_i) - acc = acc / l_i[:, None] - - # write output - if EVEN_D: - tl.store( - Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, - acc, - mask=offs_m[:, None] < q_seqlen, - ) - else: - tl.store( - Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, - acc, - mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), - ) diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py deleted file mode 100644 index c6f6cc29793f4..0000000000000 --- a/vllm/attention/ops/blocksparse_attention/interface.py +++ /dev/null @@ -1,239 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math - -import torch - -from vllm.platforms import current_platform - -from .utils import (dense_to_crow_col, get_head_sliding_step, - get_sparse_attn_mask) - -IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) - -if IS_COMPUTE_8_OR_ABOVE: - from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd - - -class LocalStridedBlockSparseAttn(torch.nn.Module): - - def __init__( - self, - n_heads, - max_seqlen, - local_blocks, - vert_stride, - block_size, - device=None, - dtype=None, - homo_head=False, - active_head_range=None, - q_block_size=None, - use_spda=None, - ): - super().__init__() - if use_spda is None: - use_spda = current_platform.is_rocm() or \ - current_platform.is_cpu() or not \ - IS_COMPUTE_8_OR_ABOVE - device = device or (torch.cuda.current_device() - if current_platform.is_cuda_alike() else "cpu") - device = torch.device(device) - # NOTE: vllm CPU backend support BF16 instead of FP16. 
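As an aside on the `layout_crow_indices` / `layout_col_indices` pair the kernel above iterates over: it is a block-level CSR encoding of the sparsity pattern, one row of pointers per query block. A small, hypothetical illustration (the real layout produced by `dense_to_crow_col` may also carry a head dimension; this sketch ignores that):

```python
import torch

# One row per query block, 1 where that query block attends to a key block.
dense_block_mask = torch.tensor([
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [0, 1, 1, 0],
    [1, 0, 1, 1],
])
crow = torch.cat([torch.zeros(1, dtype=torch.long),
                  dense_block_mask.sum(dim=1).cumsum(0)])   # [0, 1, 3, 5, 8]
col = dense_block_mask.nonzero()[:, 1]                      # [0, 0, 1, 1, 2, 0, 2, 3]
# The kernel reads crow[q_block] .. crow[q_block + 1] to get the slice of
# `col` listing that query block's key blocks.
```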
- dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE - or device.type == "cpu" else torch.half) - - self.n_heads = n_heads - self.max_seqlen = max_seqlen - self.local_blocks = local_blocks - self.vert_stride = vert_stride - self.use_spda = use_spda - self.dtype = dtype - self.device = device - self.block_size = block_size - self.q_block_size = q_block_size - self.homo_head = homo_head - self.active_head_range = active_head_range - self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride, - homo_head) - - sparse_layout, sparse_pattern, self.dense_attn_mask = ( - self.get_attn_pattern(dtype, device)) - - if q_block_size is not None and q_block_size != block_size: - if q_block_size > block_size: - assert q_block_size % block_size == 0 - blocks_to_merge = q_block_size // block_size - shape = sparse_pattern.shape - sparse_pattern = sparse_pattern.view(shape[0], -1, - blocks_to_merge, - shape[-1]) - sparse_pattern = sparse_pattern.sum(2) - sparse_layout = dense_to_crow_col(sparse_pattern) - else: - raise ValueError( - "Does not support smaller q_block_size. It will be slower." - ) - - self.sparse_layout = sparse_layout - - def get_attn_pattern(self, dtype, device): - sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask( - self.n_heads, - self.max_seqlen, - self.max_seqlen, - dtype, - device, - block_size=self.block_size, - local_blocks=self.local_blocks, - vert_stride=self.vert_stride, - homo_head=self.homo_head, - return_dense=self.use_spda, - dense_mask_type="bias", - ) - if (not self.homo_head) and (self.active_head_range is not None): - assert isinstance(self.active_head_range, tuple) - assert (len(self.active_head_range) == 2) - h_start, h_end = self.active_head_range - sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout) - if self.use_spda: - dense_attn_mask = dense_attn_mask[h_start:h_end] - return sparse_layout, sparse_pattern, dense_attn_mask - - def varlen_attn(self, - q, - k, - v, - cu_seqlens_k, - cu_seqlens_q=None, - sm_scale=None): - """ - q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). - Support grouped attention, with `q[:, i*r:(i*r + r)]` - is correspondent to `k[:, i]`, where `r` is the q/k ratio. - cu_seqlens_k: shape=(batch_size + 1,), - indicating segment of samples, - e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i - cu_seqlens_q: shape=(batch_size + 1, ). - Default None: same as cu_seqlens_k for prefilling or - [0, 1, .., batch_size] for decoding. - The only case you need to specify is when q is a mix of - prefilling and decoding. - sm_scale: softmax scale, default to 1/sqrt(head_size). - - return: tensor of shape as q. - """ - assert ( - IS_COMPUTE_8_OR_ABOVE - ), "Requires compute capability of 8 or above (Ampere or newer) to use \ - Triton kernel." 
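The `cu_seqlens_*` arguments documented in the removed `varlen_attn` are cumulative sequence boundaries over token-packed inputs. A small sketch of how such a tensor is typically built and what the grouped-head (GQA) mapping in the docstring means; `make_cu_seqlens` is a hypothetical helper, not a vLLM function:

```python
import torch

def make_cu_seqlens(seq_lens):
    """[5, 3, 7] -> tensor([0, 5, 8, 15]); sample i occupies rows cu[i]:cu[i+1]."""
    cu = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu[1:] = torch.as_tensor(seq_lens, dtype=torch.int32).cumsum(0)
    return cu

cu_seqlens_k = make_cu_seqlens([5, 3, 7])
# Packed layout: q, k, v each have shape (sum(seq_lens), n_heads, head_size).
# With grouped attention and ratio r = n_heads_q // n_heads_kv, query heads
# q[:, i*r:(i+1)*r] attend against key/value head i.
```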
- - sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) - - return blocksparse_flash_attn_varlen_fwd( - q, - k, - v, - cu_seqlens_k, - cu_seqlens_q, - sm_scale, - self.sparse_layout, - block_size=self.block_size, - q_block_size=self.q_block_size, - max_seqlen=self.max_seqlen, - ) - - @staticmethod - def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1): - """ - :param x: (total_tokens, n_heads, head_size) - :return: (batch, n_heads, length, head_size) - """ - x_padded = x.new_empty( - len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2)) - cu_seqlens = cu_seqlens.cpu() - for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): - x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0, - 1).unsqueeze(1)) - return x_padded.flatten(1, 2) - - @staticmethod - def transpose_and_unpad(x_padded, cu_seqlens): - """ - :param x_padded: (batch, n_heads, length, head_size) - :return: (total_tokens, n_heads, head_size) - """ - cu_seqlens = cu_seqlens.cpu() - total_n_tokens = cu_seqlens[-1] - x = x_padded.new_empty(total_n_tokens, x_padded.size(1), - x_padded.size(3)) - for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): - x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1)) - return x - - def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): - """For CPU, V100 or other older GPUs. - NOTE: torch SPDA supports nested tensor, - but seems extremely slow. Choose to pad instead. - """ - assert (cu_seqlens_q is None or - (cu_seqlens_q - == cu_seqlens_k).all()), "Can only handle prompt with SPDA." - assert q.size(0) == k.size(0), "can only handle prompt with SPDA." - - assert q.size(1) % k.size(1) == 0 - q_k_ratio = q.size(1) // k.size(1) - sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) - cu_seqlens = cu_seqlens_k.cpu() - maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - - if (self.dense_attn_mask.dtype != q.dtype - or self.dense_attn_mask.device != q.device): - _, _, self.dense_attn_mask = self.get_attn_pattern( - q.dtype, q.device) - attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen] - - q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1) - k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio) - for x in [k, v]) - spda_output = torch.nn.functional.scaled_dot_product_attention( - q2, k2, v2, attn_mask=attn_mask, scale=sm_scale) - return self.transpose_and_unpad(spda_output, cu_seqlens) - - def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): - """Dispatch to `varlen_attn` (Ampere or newer) or - `self.spda`(cpu, Volta, Turing or older)based on - the type of device used and cuda compute capability. - - q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). - Support grouped attention, with `q[:, i*r:(i*r + r)]` - is correspondent to `k[:, i]`, where `r` is the q/k ratio. - cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples, - e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i - cu_seqlens_q: shape=(batch_size + 1, ). - Default None: same as cu_seqlens_k for prefilling or - [0, 1, .., batch_size] for decoding. - The only case you need to specify - is when q is a mix of prefilling - and decoding. - sm_scale: softmax scale, default to 1/sqrt(head_size). - - return: tensor of shape as q. 
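On CPU and pre-Ampere GPUs the removed module fell back to `torch.nn.functional.scaled_dot_product_attention` over padded batches, passing the dense sparsity pattern as an additive bias (0 where attention is allowed, `-inf` where it is masked). A stripped-down sketch of that shape flow, using a plain causal bias in place of the block-sparse `dense_attn_mask`; sizes are illustrative:

```python
import torch
import torch.nn.functional as F

batch, n_heads, maxlen, head_size = 2, 4, 16, 32
q = torch.randn(batch, n_heads, maxlen, head_size)
k = torch.randn(batch, n_heads, maxlen, head_size)
v = torch.randn(batch, n_heads, maxlen, head_size)

# Additive bias cropped to the current max length, as in the removed spda();
# here a simple causal mask stands in for the block-sparse dense bias.
attn_bias = torch.triu(torch.full((maxlen, maxlen), float("-inf")), diagonal=1)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias[None, None])
# out: (batch, n_heads, maxlen, head_size); the removed code then unpadded
# back to the token-packed (total_tokens, n_heads, head_size) layout.
```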
- """ - assert k.dim() == 3 - if self.use_spda: - return self.spda( - q, - k, - v, - cu_seqlens_k, - cu_seqlens_q=cu_seqlens_q, - sm_scale=sm_scale, - ) - return self.varlen_attn(q, - k, - v, - cu_seqlens_k, - cu_seqlens_q=cu_seqlens_q, - sm_scale=sm_scale) diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py deleted file mode 100644 index 445720c709c47..0000000000000 --- a/vllm/attention/ops/blocksparse_attention/utils.py +++ /dev/null @@ -1,246 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Helper functions for 3D sparse pattern -# These function are not optimized and very inefficient. -# Avoid calling them too frequent or use a cache mechanism. - -from functools import lru_cache - -import numpy as np -import torch - -from vllm.triton_utils import triton - - -class csr_matrix: - """Simple implementation of CSR matrix conversion without scipy. - This replaced scipy.sparse.csr_matrix() previously used.""" - - def __init__(self, input_array): - if not isinstance(input_array, np.ndarray): - raise ValueError("Input must be a NumPy array") - - self.shape = input_array.shape - rows, cols = self.shape - data = [] - indices = [] - indptr = [0] - - for i in range(rows): - for j in range(cols): - if input_array[i, j]: - data.append(input_array[i, j]) - indices.append(j) - indptr.append(len(indices)) - - self.data = np.array(data) - self.indices = np.array(indices) - self.indptr = np.array(indptr) - - -def dense_to_crow_col(x: torch.Tensor): - """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing. - NOTE: col_indices padded -1 - """ - device = x.device - pad = -1 - dim = x.dim() - assert x.dim() in (2, 3) - if x.dim() == 2: - x = x[None] - x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x] - crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) - cols = [torch.from_numpy(xi.indices) for xi in x] - max_cols = max(len(xi) for xi in cols) - cols = [ - torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) - for xi in cols - ] - cols = torch.vstack(cols) - if dim == 2: - crows = crows[0] - cols = cols[0] - return crows.to(device), cols.to(device) - - -def crow_col_to_dense(crows: torch.Tensor, - cols: torch.Tensor, - dtype: torch.dtype = torch.float16): - dim = crows.dim() - if dim == 1: - crows = crows[None] - cols = cols[None] - device = crows.device - crows, cols = crows.cpu(), cols.cpu() # faster in cpu - shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) - x = torch.zeros(shape, dtype=dtype) - for i in range(shape[0]): - for j in range(shape[1]): - x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1 - if dim == 1: - x = x[0] - return x.to(device) - - -def dense_to_ccol_row(x: torch.Tensor): - """Similar, but to CSC format""" - x = x.transpose(-2, -1) - return dense_to_crow_col(x) - - -def ccol_row_to_dense(ccol: torch.Tensor, - rows: torch.Tensor, - dtype: torch.dtype = torch.float16): - return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() - - -def _get_sparse_attn_mask_homo_head( - q_len: int, - max_seqlen: int, - dtype: torch.dtype, - device: torch.device, - block_size: int = 128, - local_blocks: int = 4, - vert_stride: int = 4, - return_dense: bool = False, -): - """ - :return: a tuple of 3: - - tuple of crow_indices, col_indices representation - of CSR format. 
- - block dense mask - - all token dense mask (be aware that it can be - OOM if it is too big) if `return_dense==True`, - otherwise, None - """ - with torch.no_grad(): - num_blocks = triton.cdiv(max_seqlen, block_size) - q_pos = torch.arange(num_blocks)[:, None] - k_pos = torch.arange(num_blocks)[None] - mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0 - block_mask_dense = (((q_pos >= k_pos) - & ((q_pos - k_pos < local_blocks) - | mask_vert_strided)).to(device).to(dtype)) - num_blocks_q = triton.cdiv(q_len, block_size) - block_mask_dense_output = (dense_to_crow_col( - block_mask_dense[-num_blocks_q:].contiguous())) - if return_dense: - mask_dense = torch.kron( - block_mask_dense, - block_mask_dense.new_ones((block_size, block_size)), - ) - causal_mask = torch.tril(torch.ones( - max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] - mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask - return ( - block_mask_dense_output, - block_mask_dense, - mask_dense, - ) - else: - return ( - block_mask_dense_output, - block_mask_dense, - None, - ) - - -def binary_mask_to_bias(mask_dense: torch.Tensor): - mask_dense = 1 - mask_dense - mask_dense.masked_fill_(mask_dense.bool(), -torch.inf) - return mask_dense - - -def get_head_sliding_step(n_heads: int, - vert_stride: int, - homo_head: bool = False): - if homo_head: - return 0 - return max(1, int(vert_stride / n_heads)) - - -@lru_cache -def get_sparse_attn_mask( - n_heads: int, - q_len: int, - max_seqlen: int, - dtype: torch.dtype, - device: torch.device, - block_size: int = 64, - local_blocks: int = 4, - vert_stride: int = 4, - homo_head: bool = True, - return_dense: bool = False, - dense_mask_type: str = "binary", -): - """ - :param dense_mask_type: "binary" (0 for skip token, 1 for others) - or "bias" (-inf for skip token, 0 or others) - :return: a tuple of 3: - - tuple of crow_indices, col_indices representation - of CSR format. 
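The sparsity rule built in the deleted `_get_sparse_attn_mask_homo_head` keeps, for each query block, the `local_blocks` most recent key blocks plus every `vert_stride`-th key block, all under the causal constraint. A tiny dense illustration with made-up sizes:

```python
import torch

num_blocks, local_blocks, vert_stride = 8, 2, 4
q_pos = torch.arange(num_blocks)[:, None]
k_pos = torch.arange(num_blocks)[None, :]
vert = (torch.arange(num_blocks) + 1) % vert_stride == 0   # strided "global" key blocks
mask = (q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | vert)
print(mask.int())
# Each row keeps the last `local_blocks` blocks up to the diagonal plus the
# strided columns; dense_to_crow_col() turned this block mask into the
# crow/col (CSR) indices consumed by the deleted Triton kernel.
```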
- - block dense mask - - all token dense mask (be aware that it can be OOM if it - is too big) if `return_dense==True`, otherwise, None - """ - assert dense_mask_type in ("binary", "bias") - if homo_head: - with torch.no_grad(): - (crow, col), block_mask_dense, mask_dense = ( - _get_sparse_attn_mask_homo_head( - q_len, - max_seqlen, - dtype, - device, - block_size, - local_blocks, - vert_stride, - return_dense, - )) - crow = crow[None].expand(n_heads, crow.shape[0]) - col = col[None].expand(n_heads, col.shape[0]) - if return_dense: - mask_dense = mask_dense[None].expand(n_heads, - *mask_dense.shape) - if dense_mask_type == "bias": - mask_dense = binary_mask_to_bias(mask_dense) - return (crow, col), block_mask_dense, mask_dense - - with torch.no_grad(): - num_blocks = triton.cdiv(max_seqlen, block_size) - q_pos = torch.arange(num_blocks)[None, :, None] - k_pos = torch.arange(num_blocks)[None, None] - head_sliding_step = get_head_sliding_step(n_heads, vert_stride) - mask_vert_strided = [ - (torch.arange(num_blocks) + h * head_sliding_step + 1) % - vert_stride == 0 for h in range(n_heads) - ] - mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) - block_mask_dense = (((q_pos >= k_pos) - & ((q_pos - k_pos < local_blocks) - | mask_vert_strided)).to(device).to(dtype)) - num_blocks_q = triton.cdiv(q_len, block_size) - block_mask_dense_output = block_mask_dense[:, -num_blocks_q:] - if return_dense: - mask_dense = torch.kron( - block_mask_dense, - block_mask_dense.new_ones((block_size, block_size)), - ) - causal_mask = torch.tril(torch.ones( - max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] - mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None] - if dense_mask_type == "bias": - mask_dense = binary_mask_to_bias(mask_dense) - - return ( - dense_to_crow_col(block_mask_dense_output), - block_mask_dense, - mask_dense, - ) - else: - return ( - dense_to_crow_col(block_mask_dense_output), - block_mask_dense, - None, - ) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4d4886d02b78e..2e3c8638125f7 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -143,7 +143,6 @@ def get_attn_backend( kv_cache_dtype: Optional[str], block_size: int, is_attention_free: bool, - is_blocksparse: bool = False, use_mla: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" @@ -157,7 +156,6 @@ def get_attn_backend( kv_cache_dtype=kv_cache_dtype, block_size=block_size, is_attention_free=is_attention_free, - is_blocksparse=is_blocksparse, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, ) @@ -170,16 +168,9 @@ def _cached_get_attn_backend( kv_cache_dtype: Optional[str], block_size: int, is_attention_free: bool, - is_blocksparse: bool = False, use_v1: bool = False, use_mla: bool = False, ) -> type[AttentionBackend]: - if is_blocksparse: - logger.info("Using BlocksparseFlashAttention backend.") - from vllm.attention.backends.blocksparse_attn import ( - BlocksparseFlashAttentionBackend) - return BlocksparseFlashAttentionBackend - # If there are no attention layers (e.g. 
we are running Mamba), # use the placeholder NO_ATTENTION if is_attention_free: diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py deleted file mode 100644 index 754ddda233f42..0000000000000 --- a/vllm/model_executor/models/phi3_small.py +++ /dev/null @@ -1,465 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from collections.abc import Iterable -from typing import Optional, Union - -import torch -from torch import nn -from transformers.configuration_utils import PretrainedConfig - -from vllm.attention import Attention -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - - -def load_column_parallel_weight(param: torch.nn.Parameter, - loaded_weight: torch.Tensor): - tp = get_tensor_model_parallel_world_size() - rk = get_tensor_model_parallel_rank() - assert param.size(0) * tp == loaded_weight.size(0) - s = rk * param.size(0) - e = (rk + 1) * param.size(0) - loaded_weight = loaded_weight[s:e] - assert param.shape == loaded_weight.shape - param.data.copy_(loaded_weight) - - -class HeadMajorQKVParallelLinear(QKVParallelLinear): - - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor): - return load_column_parallel_weight(param, loaded_weight) - - -class HeadMajorColumnParallelLinear(MergedColumnParallelLinear): - - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor): - return load_column_parallel_weight(param, loaded_weight) - - -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) -def quick_gelu(x): - return x * torch.sigmoid(1.702 * x) - - -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) -def gegelu(input, limit: Optional[float] = None): - a_gelu, a_linear = input[..., ::2], input[..., 1::2] - if limit is not None: - a_gelu = torch.where(torch.isinf(a_gelu), a_gelu, - a_gelu.clamp(min=None, max=limit)) - a_linear = torch.where( - torch.isinf(a_linear), - a_linear, - a_linear.clamp(min=-limit, max=limit), - ) - out_gelu = quick_gelu(a_gelu) - return out_gelu * (a_linear + 1) - - -class Phi3SmallMLP(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.config = config - assert (self.config.hidden_act == "gegelu" - ), "Only `gegelu` is supported for the 4.7 series of models .." 
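`gegelu`, deleted above together with the Phi-3-Small model, is a gated activation over channel-interleaved projections: even channels pass through a quick GELU, odd channels act as a linear gate, and both are optionally clamped by `limit`. A standalone copy for reference (without the `torch.compile` decorators):

```python
import torch

def quick_gelu(x):
    return x * torch.sigmoid(1.702 * x)

def gegelu(x, limit=None):
    a_gelu, a_linear = x[..., ::2], x[..., 1::2]   # interleaved gate / linear halves
    if limit is not None:
        a_gelu = torch.where(torch.isinf(a_gelu), a_gelu, a_gelu.clamp(max=limit))
        a_linear = torch.where(torch.isinf(a_linear), a_linear,
                               a_linear.clamp(min=-limit, max=limit))
    return quick_gelu(a_gelu) * (a_linear + 1)

x = torch.randn(4, 2 * 8)     # the up-projection emits 2x the intermediate size
y = gegelu(x, limit=10.0)     # -> shape (4, 8)
```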
- self.hidden_size = config.hidden_size - self.gegelu_limit = config.gegelu_limit - self.intermediate_size = config.intermediate_size - - self.up_proj = HeadMajorColumnParallelLinear( - self.hidden_size, - 2 * [self.intermediate_size], - bias=True, - quant_config=quant_config, - ) - self.down_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - ) - - def forward(self, x): - gate_up, _ = self.up_proj(x) - x = gegelu(gate_up) - x, _ = self.down_proj(x) - return x - - -class Phi3SmallSelfAttention(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.layer_idx = layer_idx - self.config = config - self.sparse_block_size = config.blocksparse_block_size - self.homo_heads = config.blocksparse_homo_head_pattern - self.local_blocks = config.blocksparse_num_local_blocks - self.vert_stride = config.blocksparse_vert_stride - - assert (config.blocksparse_block_size == - config.blocksparse_triton_kernel_block_size) - - self.hidden_size = config.hidden_size - # Number of Query Heads - self.num_heads = config.num_attention_heads - - self.head_dim = self.hidden_size // self.num_heads - self.tp_size = get_tensor_model_parallel_world_size() - # Number of total Key Value Heads before tensor parallel - self.num_key_value_heads = config.num_key_value_heads - self.num_q_per_kv = self.num_heads // self.num_key_value_heads - if self.tp_size > 1: - assert self.num_key_value_heads % self.tp_size == 0 - self.num_kv_heads_per_partition = max( - 1, self.num_key_value_heads // self.tp_size) - self.num_heads_per_partition = self.num_heads // self.tp_size - - self.max_position_embeddings = config.max_position_embeddings - self.rope_embedding_base = config.rope_embedding_base - self.rope_position_scale = config.rope_position_scale - self.is_causal = True - - norm_factor = None - if config.mup_use_scaling: - norm_factor = self.head_dim / config.mup_attn_multiplier - else: - norm_factor = math.sqrt(self.head_dim) - self.scale = 1 / norm_factor - - self.query_key_value = HeadMajorQKVParallelLinear( - self.hidden_size, - self.head_dim, - self.num_heads, - self.num_key_value_heads, - bias=True, - quant_config=quant_config, - ) - - self.dense = RowParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config) - - if getattr(self.config, "rope_scaling", None) is not None: - rope_scaling = self.config.rope_scaling - for key in rope_scaling: - if isinstance(rope_scaling[key], list): - rope_scaling[key] = tuple(rope_scaling[key]) - - if "factor" not in rope_scaling: - rope_scaling["factor"] = self.rope_position_scale - else: - rope_scaling = { - "rope_type": "linear", - "factor": self.rope_position_scale, - } - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_embedding_base, - rope_scaling=rope_scaling, - ) - - # blocksparse params - self.blocksparse_block_size = config.blocksparse_block_size - self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks - self.blocksparse_vert_stride = config.blocksparse_vert_stride - - use_dense_attn = (getattr(self.config, - "dense_attention_every_n_layers", None) - and (self.layer_idx + 1) % - self.config.dense_attention_every_n_layers == 0) - - bs_params = None - if not use_dense_attn: - bs_params = { - 'max_seqlen': 
self.max_position_embeddings, - 'num_heads': self.num_heads_per_partition, - "num_kv_heads": self.num_kv_heads_per_partition, - "block_size": self.sparse_block_size, - "local_blocks": self.local_blocks, - "vert_stride": self.vert_stride, - "homo_head": self.homo_heads - } - - self.attn = Attention(self.num_heads_per_partition, - self.head_dim, - self.scale, - num_kv_heads=self.num_kv_heads_per_partition, - cache_config=cache_config, - quant_config=quant_config, - blocksparse_params=bs_params, - prefix=f"{prefix}.attn") - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[tuple[torch.Tensor]]]: - qkv, _ = self.query_key_value(hidden_states) - - qkv = qkv.view(qkv.shape[:-1] + - (-1, (self.num_q_per_kv + 2), self.head_dim)) - q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2) - - # NOTE: this is required by RotaryEmbed, which indeed does not have to - # TODO: allow 3D QK for rotary forward - q = q.reshape(-1, self.head_dim * self.num_heads_per_partition) - k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partition) - v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partition) - - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.dense(attn_output) - - return output - - -class Phi3SmallDecoderLayer(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = Phi3SmallSelfAttention(config, - layer_idx, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") - self.mlp = Phi3SmallMLP(config, quant_config) - - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.post_attention_layernorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_epsilon) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -class Phi3SmallModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.config = config - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) - self.mup_embedding_multiplier = config.mup_embedding_multiplier - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: Phi3SmallDecoderLayer(config, - int(prefix.split('.')[-1]), - cache_config, - quant_config, - prefix=prefix), - prefix=f"{prefix}.layers") - - self.final_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return 
self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.LongTensor, - positions: Optional[torch.LongTensor], - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor], - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - if (self.mup_embedding_multiplier is not None - and self.mup_embedding_multiplier > 0.0): - hidden_states = hidden_states * self.mup_embedding_multiplier - else: - assert intermediate_tensors - hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: - hidden_states = layer(positions, hidden_states) - if not get_pp_group().is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states}) - hidden_states = self.final_layernorm(hidden_states) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Phi3SmallForCausalLM(nn.Module, SupportsPP): - _tied_weights_keys = ["lm_head.weight"] - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_suffix={"rotary_emb.inv_freq": None}) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = Phi3SmallModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.vocab_size = config.vocab_size - self.mup_width_multiplier = config.mup_width_multiplier - self.lm_head = ParallelLMHead( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - quant_config=quant_config, - ) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - # tokens in tiktoken but not used - if hasattr(config, 'dummy_token_indices'): - device = self.lm_head.weight.device - self.register_buffer('dummy_token_indices', - torch.LongTensor( - config.dummy_token_indices).to(device), - persistent=False) - else: - self.dummy_token_indices = None - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, value): - self.lm_head = value - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - if self.dummy_token_indices is not None and logits is not None: - 
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) - logits = logits / self.mup_width_multiplier - return logits - - def forward( - self, - input_ids: torch.LongTensor, - positions: Optional[torch.LongTensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - output_hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - output_hidden_states = output_hidden_states - return output_hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None)) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 2ca37867b88c6..3440dd656c509 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -110,7 +110,6 @@ _TEXT_GENERATION_MODELS = { "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), - "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index b8e788de11c65..1cd5cb5e83db7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -57,7 +57,6 @@ class _Backend(enum.Enum): PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() - BLOCK_SPARSE_FLASH_ATTN = enum.auto() DUAL_CHUNK_FLASH_ATTN = enum.auto() DIFFERENTIAL_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index d63b82012a52a..2efbe0de27255 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, Optional +from typing import Optional import numpy as np import torch @@ -443,7 +443,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -451,9 +450,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") - if blocksparse_params is not None: - raise ValueError( - "Torch SPDA does not support block-sparse attention.") if logits_soft_cap is not None: logger.warning_once("Torch SPDA does not support logits soft cap. 
" "Outputs may be slightly off.") diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index a37bf2a7115ba..ad414ee0a1fc9 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import Any, ClassVar, Optional +from typing import ClassVar, Optional import numpy as np import torch @@ -349,15 +349,11 @@ class FlashAttentionImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: - if blocksparse_params is not None: - raise ValueError( - "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 7f3c4ed129cf0..e1ffa61a6005e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, @@ -490,7 +490,6 @@ class FlashInferImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index c229ec12fd1b4..ad63f92cd88a7 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -3,7 +3,7 @@ """Attention layer with FlashAttention.""" from collections import defaultdict from dataclasses import dataclass -from typing import Any, Optional +from typing import Optional import torch from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, @@ -342,15 +342,10 @@ class FlexAttentionImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, ) -> None: - if blocksparse_params is not None: - # TODO we should support this :think - raise ValueError( - "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 93c8156b16a7f..cf17d93302395 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -190,7 +190,7 @@ return curr_o @ W_O import functools from abc import abstractmethod from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union +from typing import 
TYPE_CHECKING, Generic, Optional, TypeVar, Union import torch @@ -754,7 +754,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index a0f7c39c00412..c787f25cd3adf 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Any, Optional +from typing import Optional import torch @@ -74,7 +74,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -82,17 +81,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "CutlassMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 935311aacc35a..d3e5300dbbd6b 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, ClassVar, Optional +from typing import ClassVar, Optional import torch @@ -119,7 +119,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -127,20 +126,17 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) assert is_flashmla_supported(), \ "FlashMLA is not supported on this device" - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise 
NotImplementedError("Encoder self-attention and " diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 42a0425836154..834c234558350 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, ClassVar, Optional +from typing import ClassVar, Optional import torch @@ -167,7 +167,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -175,20 +174,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) assert (num_heads == 16 or num_heads == 128), ( f"Aiter MLA only supports 16 or 128 number of heads.\n" f"Provided {num_heads} number of heads.\n" "Try adjusting tensor_parallel_size value.") - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "Aiter MLA does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") from aiter import flash_attn_varlen_func self.flash_attn_varlen_func = flash_attn_varlen_func diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 99938f22f108c..700fce68953e5 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -42,7 +42,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -50,17 +49,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 
52e12a1a506f5..ac7980c79e4d0 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, Optional +from typing import Optional import torch import torch_xla.core.xla_builder as xb @@ -132,7 +132,6 @@ class PallasAttentionBackendImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, @@ -142,9 +141,6 @@ class PallasAttentionBackendImpl(AttentionImpl): logger.warning_once( "Using irope in Pallas is not supported yet, it will fall back " "to global attention for long context.") - if blocksparse_params is not None: - raise ValueError("Paged attention Pallas kernel does " - "not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -158,8 +154,6 @@ class PallasAttentionBackendImpl(AttentionImpl): raise NotImplementedError("Alibi slopes is not supported.") if kv_cache_dtype != "auto": raise NotImplementedError("FP8 KV cache dtype is not supported.") - if blocksparse_params is not None: - raise NotImplementedError("Blocksparse is not supported.") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 43fe30a9a89f0..8f75676394494 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import Any, Optional +from typing import Optional import torch @@ -334,15 +334,11 @@ class AiterFlashAttentionImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, use_irope: bool = False, ) -> None: - if blocksparse_params is not None: - raise ValueError( - "AiterFlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 79796ac149283..d65ff5ff74ece 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" from dataclasses import dataclass -from typing import Any, ClassVar, Optional +from typing import ClassVar, Optional import torch @@ -205,15 +205,11 @@ class TritonAttentionImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, use_irope: bool = False, ) -> None: - if blocksparse_params is not 
None: - raise ValueError( - "TritonAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale)
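Taken together with the selector change above, none of the remaining attention implementations accept `blocksparse_params`, and `get_attn_backend` no longer takes `is_blocksparse`. The constructor shape that the kept lines converge on is roughly the following; the class name is illustrative, not a vLLM class:

```python
from typing import Optional

class SomeAttentionImpl:
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = "decoder",
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
```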