Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-24 00:14:34 +08:00)

[V0 Deprecation] Deprecate BlockSparse Attention & Phi3-Small (#21217)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

This commit is contained in: parent 881e3cbe3b, commit 752c6ade2e
@@ -108,7 +108,6 @@ fi
if [[ $commands == *" kernels/attention"* ]]; then
commands="${commands} \
--ignore=kernels/attention/test_attention_selector.py \
--ignore=kernels/attention/test_blocksparse_attention.py \
--ignore=kernels/attention/test_encoder_decoder_attn.py \
--ignore=kernels/attention/test_flash_attn.py \
--ignore=kernels/attention/test_flashinfer.py \

@@ -376,7 +376,6 @@ Specified using `--task generate`.
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/Phi-4-mini-flash-reasoning`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |

@@ -1,441 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops.blocksparse_attention.interface import (
|
||||
LocalStridedBlockSparseAttn)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_max_shared_memory_bytes
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
# This will change depending on the compute capability.
|
||||
# - 512 as a buffer
|
||||
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
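# Worked example (assumed hardware value, for illustration only): on a GPU
# exposing 96 KB of shared memory per block, get_max_shared_memory_bytes()
# returns 98304, so MAX_SEQ_LEN = 98304 // 4 - 512 = 24064.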
|
||||
# MAX_SEQ_LEN = 2771
|
||||
|
||||
# There may not be enough gpu memory due to large NUM_BLOCKS.
|
||||
# Reduce NUM_BLOCKS when it happens.
|
||||
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
||||
PARTITION_SIZE = 512
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
NUM_GEN_SEQS = [3] # Arbitrary values for testing
|
||||
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
|
||||
NUM_HEADS = [(40, 40)] # Arbitrary values for testing
|
||||
|
||||
HEAD_SIZES = [64, 112]
|
||||
BLOCK_SIZES = [16]
|
||||
USE_ALIBI = [False, True]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = ['cuda:0']
|
||||
BLOCKSPARSE_LOCAL_BLOCKS = [16]
|
||||
BLOCKSPARSE_VERT_STRIDES = [8]
|
||||
|
||||
BLOCKSPARSE_BLOCK_SIZES = [64]
|
||||
BLOCKSPARSE_HEADS_SLIDINGS = [2, -1]
|
||||
BLOCKSPARSE_HOMO_HEADS = [True, False]
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||
if attn_mask is not None:
|
||||
attn_weights = attn_weights + attn_mask.float()
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
||||
return out
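# Minimal usage sketch of the reference above (illustrative shapes only; this
# helper is not part of the original test): one sequence of 4 tokens, 2 heads,
# head_size 8, no mask.
def _example_ref_masked_attention() -> torch.Tensor:
    q = torch.randn(4, 2, 8)
    k = torch.randn(4, 2, 8)
    v = torch.randn(4, 2, 8)
    # Output keeps the same [num_tokens, num_heads, head_size] layout as `q`.
    return ref_masked_attention(q, k, v, scale=8**-0.5)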
|
||||
|
||||
|
||||
def ref_single_query_cached_kv_attention(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
num_queries_per_kv: int,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 1,
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> None:
|
||||
num_query_heads = query.shape[1]
|
||||
num_kv_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs = query.shape[0]
|
||||
|
||||
block_tables_lst = block_tables.cpu().tolist()
|
||||
seq_lens_lst = seq_lens.cpu().tolist()
|
||||
for i in range(num_seqs):
|
||||
q = query[i].unsqueeze(0)
|
||||
block_table = block_tables_lst[i]
|
||||
seq_len = int(seq_lens_lst[i])
|
||||
|
||||
keys_lst: list[torch.Tensor] = []
|
||||
values_lst: list[torch.Tensor] = []
|
||||
for j in range(seq_len):
|
||||
block_number = int(block_table[j // block_size])
|
||||
block_offset = j % block_size
|
||||
|
||||
k = key_cache[block_number, :, :, block_offset, :]
|
||||
k = k.reshape(num_kv_heads, head_size)
|
||||
keys_lst.append(k)
|
||||
|
||||
v = value_cache[block_number, :, :, block_offset]
|
||||
values_lst.append(v)
|
||||
keys = torch.stack(keys_lst, dim=0)
|
||||
values = torch.stack(values_lst, dim=0)
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
|
||||
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
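# Illustrative numbers (hypothetical; this test actually uses equal head
# counts): with 8 query heads and 2 KV heads, num_queries_per_kv = 4 and
# repeat_interleave expands keys/values from [seq_len, 2, head_size] to
# [seq_len, 8, head_size], so every query head has a matching KV head.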
|
||||
|
||||
alibi_bias = None
|
||||
if alibi_slopes is not None:
|
||||
# Create the ALiBi bias used in the paged attention kernel.
|
||||
position_ids = torch.arange(seq_len).int()
|
||||
alibi_bias = (position_ids - seq_len + 1).float()
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||
1, 1, -1)
|
||||
|
||||
if blocksparse_vert_stride >= 1:
|
||||
bsize = blocksparse_block_size
|
||||
hsliding = blocksparse_head_sliding_step
|
||||
vert = blocksparse_vert_stride
|
||||
num_local_blocks = blocksparse_local_blocks  # avoid shadowing the builtin `locals`
|
||||
qb = (seq_len - 1) // bsize
|
||||
attn_mask = q.new_zeros(
|
||||
(num_query_heads, 1, seq_len)).float() - torch.inf
|
||||
for h in range(num_query_heads):
|
||||
if hsliding >= 0: # slide with q heads
|
||||
bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1
|
||||
else: # slide with kv heads
|
||||
bs_offset = (tp_rank * num_kv_heads +
|
||||
h // num_queries_per_kv) * (-hsliding) + 1
|
||||
for kb in range(qb + 1):
|
||||
kj = kb * bsize
|
||||
if (qb - kb) < num_local_blocks or \
|
||||
(kb + bs_offset) % vert == 0:
|
||||
attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0
|
||||
if alibi_bias is not None:
|
||||
attn_mask += alibi_bias
|
||||
else:
|
||||
attn_mask = alibi_bias
|
||||
|
||||
out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask)
|
||||
out = out.view(num_query_heads, head_size)
|
||||
output[i].copy_(out, non_blocking=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("version", ["v1", "v2"])
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
|
||||
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
|
||||
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("blocksparse_head_sliding_step",
|
||||
BLOCKSPARSE_HEADS_SLIDINGS)
|
||||
def test_paged_attention(
|
||||
kv_cache_factory,
|
||||
version: str,
|
||||
num_seqs: int,
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
blocksparse_local_blocks: int,
|
||||
blocksparse_vert_stride: int,
|
||||
blocksparse_block_size: int,
|
||||
blocksparse_head_sliding_step: int,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.rand(num_query_heads, dtype=torch.float)
|
||||
|
||||
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
seq_lens[-1] = MAX_SEQ_LEN
|
||||
max_seq_len = max(seq_lens)
|
||||
seq_lens = torch.tensor(seq_lens, dtype=torch.int)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
tp_rank = 0
|
||||
|
||||
# Call the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
if version == "v1":
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank=tp_rank,
|
||||
blocksparse_local_blocks=blocksparse_local_blocks,
|
||||
blocksparse_vert_stride=blocksparse_vert_stride,
|
||||
blocksparse_block_size=blocksparse_block_size,
|
||||
blocksparse_head_sliding_step=blocksparse_head_sliding_step,
|
||||
)
|
||||
elif version == "v2":
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
assert PARTITION_SIZE % block_size == 0
|
||||
num_seqs, num_heads, head_size = output.shape
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank=tp_rank,
|
||||
blocksparse_local_blocks=blocksparse_local_blocks,
|
||||
blocksparse_vert_stride=blocksparse_vert_stride,
|
||||
blocksparse_block_size=blocksparse_block_size,
|
||||
blocksparse_head_sliding_step=blocksparse_head_sliding_step,
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
|
||||
# Run the reference implementation.
|
||||
if kv_cache_dtype == "fp8":
|
||||
# Convert cache data back to dtype.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
|
||||
block_size, x)
|
||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
ops.convert_fp8(dequantized_key_cache, key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
ops.convert_fp8(dequantized_value_cache, value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
ref_output = torch.empty_like(query)
|
||||
ref_single_query_cached_kv_attention(
|
||||
ref_output,
|
||||
query,
|
||||
num_queries_per_kv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
scale,
|
||||
alibi_slopes,
|
||||
tp_rank,
|
||||
blocksparse_local_blocks,
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||
# implementations, there is a small numerical difference in the two
|
||||
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
|
||||
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
|
||||
|
||||
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
|
||||
# so we use a relaxed tolerance for the test.
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
atol, rtol = 1e-2, 1e-5
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def ref_multi_query_kv_attention(
|
||||
cu_seq_lens: list[int],
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(cu_seq_lens) - 1
|
||||
ref_outputs = []
|
||||
for i in range(num_seqs):
|
||||
start_idx = cu_seq_lens[i]
|
||||
end_idx = cu_seq_lens[i + 1]
|
||||
seq_len = end_idx - start_idx
|
||||
|
||||
# Create attention mask.
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||
diagonal=1)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype)
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
scale,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
ref_outputs.append(ref_output)
|
||||
ref_output = torch.cat(ref_outputs, dim=0)
|
||||
return ref_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
|
||||
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
|
||||
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_varlen_blocksparse_attention_prefill(
|
||||
num_seqs: int,
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
blocksparse_local_blocks: int,
|
||||
blocksparse_vert_stride: int,
|
||||
blocksparse_block_size: int,
|
||||
blocksparse_homo_heads: bool,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
|
||||
# As the xformers library is already tested with its own tests, we can use
|
||||
# a smaller MAX_SEQ_LEN here.
|
||||
max_len = min(MAX_SEQ_LEN, 4096)
|
||||
seq_lens = random.sample(range(1, max_len), num_seqs)
|
||||
cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)
|
||||
num_tokens = sum(seq_lens)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
|
||||
qkv = torch.empty(num_tokens,
|
||||
num_query_heads + 2 * num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
qkv.uniform_(-scale, scale)
|
||||
query, key, value = qkv.split(
|
||||
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||
|
||||
bs_attn_op = LocalStridedBlockSparseAttn(
|
||||
num_query_heads,
|
||||
max_len,
|
||||
local_blocks=blocksparse_local_blocks,
|
||||
vert_stride=blocksparse_vert_stride,
|
||||
block_size=blocksparse_block_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
homo_head=blocksparse_homo_heads)
|
||||
|
||||
output = bs_attn_op(query,
|
||||
key,
|
||||
value,
|
||||
cu_seq_lens.to(device),
|
||||
sm_scale=scale)
|
||||
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
|
||||
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
|
||||
|
||||
ref_output = ref_multi_query_kv_attention(
|
||||
cu_seq_lens.tolist(),
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
scale,
|
||||
dtype,
|
||||
)
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
|
||||
@@ -33,8 +33,12 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
# change the attention backend to triton MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||
False, True)
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
16,
|
||||
False,
|
||||
use_mla=True)
|
||||
assert (backend.get_name() == "TRITON_MLA"
|
||||
or backend.get_name() == "TRITON_MLA_VLLM_V1")
|
||||
|
||||
@@ -42,15 +46,23 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
# If use_mla is true
|
||||
# The selected backend is triton MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||
False, True)
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
16,
|
||||
False,
|
||||
use_mla=True)
|
||||
assert (backend.get_name() == "TRITON_MLA"
|
||||
or backend.get_name() == "TRITON_MLA_VLLM_V1")
|
||||
|
||||
# change the attention backend to AITER MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
|
||||
False, True)
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
1,
|
||||
False,
|
||||
use_mla=True)
|
||||
assert (backend.get_name() == "ROCM_AITER_MLA"
|
||||
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
|
||||
|
||||
@@ -60,7 +72,11 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
# The selected backend is ROCM_AITER_MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
|
||||
False, True)
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
1,
|
||||
False,
|
||||
use_mla=True)
|
||||
assert (backend.get_name() == "ROCM_AITER_MLA"
|
||||
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
|
||||
|
||||
@@ -247,10 +247,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
|
||||
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
|
||||
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
|
||||
# Blocksparse attention not supported in V1 yet
|
||||
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
|
||||
trust_remote_code=True,
|
||||
v0_only=True),
|
||||
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
|
||||
trust_remote_code=True,
|
||||
v0_only=True,
|
||||
|
||||
@@ -269,7 +269,6 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
kv_cache_dtype: str = "auto",
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
|
||||
@@ -1,466 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
CommonMetadataBuilder)
|
||||
from vllm.attention.ops.blocksparse_attention.interface import (
|
||||
LocalStridedBlockSparseAttn, get_head_sliding_step)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlocksparseParams:
|
||||
max_seqlen: int
|
||||
|
||||
# Num q heads per tensor-parallel rank/partition
|
||||
num_heads: int # per TP partition
|
||||
# Num kv heads per tensor-parallel rank/partition
|
||||
num_kv_heads: int
|
||||
|
||||
# block size used for blocksparse attention.
|
||||
# This is the block_size used in `local_blocks`, `vert_stride`.
|
||||
block_size: int
|
||||
|
||||
# Number of blocks for local attention, i.e., number of
|
||||
# local attended tokens / `sparse_block_size`
|
||||
local_blocks: int
|
||||
|
||||
# Attend to one block out of every `vert_stride` blocks.
|
||||
# Controlling the sparsity
|
||||
vert_stride: int
|
||||
"""
|
||||
Whether to use the same vertical stride offset for all heads,
|
||||
i.e., attend to the same block of tokens on all heads.
|
||||
By default, it is False, i.e., attention on the non-local
|
||||
blocks depends on the `head_idx`, that is on
|
||||
blocks satisfying
|
||||
`(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
|
||||
where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
|
||||
`block_idx = position_id // sparse_block_size`.
|
||||
See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
|
||||
for more detail.
|
||||
"""
|
||||
homo_head: bool = False
|
||||
|
||||
# Whether, within a group, the kv offsets that each q head attends are the same or not.
|
||||
homo_head_group: bool = False
|
||||
|
||||
# Derived from homo_head and homo_head_group
|
||||
head_sliding_step: int = field(init=False)
|
||||
|
||||
# range of q heads handled by this TP rank
|
||||
active_head_range: Tuple = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.block_size > 0
|
||||
assert self.local_blocks >= 0
|
||||
assert self.vert_stride >= 1
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
total_heads = tp_size * self.num_heads
|
||||
total_kv_heads = tp_size * self.num_kv_heads
|
||||
|
||||
if self.homo_head:
|
||||
self.head_sliding_step = 0
|
||||
elif self.homo_head_group:
|
||||
head_sliding_step = get_head_sliding_step(total_kv_heads,
|
||||
self.vert_stride)
|
||||
# negative indicates sliding along kv heads, i.e., homo q group
|
||||
self.head_sliding_step = -head_sliding_step
|
||||
else:
|
||||
self.head_sliding_step = get_head_sliding_step(
|
||||
total_heads, self.vert_stride)
|
||||
|
||||
self.active_head_range = (
|
||||
tp_rank * self.num_heads,
|
||||
(tp_rank + 1) * self.num_heads,
|
||||
)
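# Illustrative sketch, not part of the original backend file: it applies the
# sliding-stride rule documented in the BlocksparseParams docstring above (the
# same rule used by the reference mask in the deleted test) to list which KV
# blocks one query head attends. The default values below are hypothetical
# example settings, not vLLM defaults.
def example_attended_kv_blocks(query_block_idx: int,
                               head_idx: int,
                               local_blocks: int = 16,
                               vert_stride: int = 8,
                               head_sliding_step: int = 1) -> list:
    attended = []
    for kv_block_idx in range(query_block_idx + 1):
        # Local window: the last `local_blocks` blocks up to and including
        # the query block are always attended.
        is_local = (query_block_idx - kv_block_idx) < local_blocks
        # Strided pattern: one block out of every `vert_stride`, with a
        # per-head offset of `head_idx * head_sliding_step`.
        is_strided = ((kv_block_idx + head_idx * head_sliding_step + 1)
                      % vert_stride == 0)
        if is_local or is_strided:
            attended.append(kv_block_idx)
    return attended


# e.g. example_attended_kv_blocks(40, 0) -> blocks 25..40 (the local window)
# plus the strided blocks 7, 15 and 23.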
|
||||
|
||||
|
||||
class BlocksparseFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "BLOCK_SPARSE_FLASH_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
|
||||
return BlocksparseFlashAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return BlocksparseFlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
|
||||
return BlocksparseFlashAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: Dict[int, int],
|
||||
) -> None:
|
||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: Dict[int, List[int]],
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
|
||||
"""A copy of Metadata for FlashAttentionBackend,
|
||||
to avoid having to install flash_attn.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, they should be stored in a tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens. None if this is a decode-only batch.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
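# Worked example (hypothetical numbers): if 12 tokens of a sequence were
# computed in earlier iterations and 4 new tokens are scheduled in this one,
# then context_len = 12, query_len = 4 and seq_len = 12 + 4 = 16.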
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Max number of query tokens among the requests in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
_cached_prefill_metadata: Optional[
|
||||
"BlocksparseFlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional[
|
||||
"BlocksparseFlashAttentionMetadata"] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(
|
||||
self) -> Optional["BlocksparseFlashAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.query_start_loc is not None
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_start_loc is not None
|
||||
|
||||
self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=self.block_tables[:self.num_prefills],
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
max_query_len=None,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class BlocksparseFlashAttentionMetadataBuilder(
|
||||
CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
|
||||
|
||||
_metadata_cls = BlocksparseFlashAttentionMetadata
|
||||
|
||||
|
||||
class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|<--------------- num_prompt_tokens -------------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|<------------------ num_generation_tokens (M) ----------------->|
|
||||
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|
||||
|
||||
Generation tokens can contain padding when cuda-graph is used.
|
||||
Currently, prompt tokens don't contain any padding.
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"BLOCK_SPARSE_FLASH_ATTN Backend.")
|
||||
assert blocksparse_params is not None
|
||||
assert alibi_slopes is None, ValueError(
|
||||
"Alibi not support for blocksparse flash attention.")
|
||||
assert sliding_window is None, ValueError(
|
||||
"sliding_window is invalid for blocksparse attention.")
|
||||
assert logits_soft_cap is None, ValueError(
|
||||
"logits_soft_cap is invalid for blocksparse attention.")
|
||||
|
||||
if "num_heads" not in blocksparse_params:
|
||||
blocksparse_params["num_heads"] = num_heads
|
||||
if "num_kv_heads" not in blocksparse_params:
|
||||
blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
|
||||
self.blocksparse_params = BlocksparseParams(**blocksparse_params)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.num_kv_heads = num_kv_heads
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.local_blocks = self.blocksparse_params.local_blocks
|
||||
self.vert_stride = self.blocksparse_params.vert_stride
|
||||
self.sparse_block_size = self.blocksparse_params.block_size
|
||||
self.head_sliding_step = self.blocksparse_params.head_sliding_step
|
||||
|
||||
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {supported_head_sizes}.")
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
total_num_heads = num_heads * self.tp_size
|
||||
self.bs_attn = LocalStridedBlockSparseAttn(
|
||||
total_num_heads,
|
||||
self.blocksparse_params.max_seqlen,
|
||||
self.blocksparse_params.local_blocks,
|
||||
self.blocksparse_params.vert_stride,
|
||||
self.blocksparse_params.block_size,
|
||||
homo_head=self.blocksparse_params.homo_head,
|
||||
active_head_range=self.blocksparse_params.active_head_range,
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"BlocksparseFlashAttentionImpl")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: BlocksparseFlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for BlocksparseFlashAttentionImpl")
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
|
||||
# Prompt run.
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
|
||||
assert kv_cache.numel() == 0 \
|
||||
or prefill_meta.block_tables is None \
|
||||
or prefill_meta.block_tables.numel() == 0, \
|
||||
"Does not support prefix-enabled attention."
|
||||
|
||||
output = self.bs_attn(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
sm_scale=self.scale,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor,
|
||||
self.blocksparse_params.max_seqlen,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
tp_rank=self.tp_rank,
|
||||
blocksparse_local_blocks=self.local_blocks,
|
||||
blocksparse_vert_stride=self.vert_stride,
|
||||
blocksparse_block_size=self.sparse_block_size,
|
||||
blocksparse_head_sliding_step=self.head_sliding_step,
|
||||
)
|
||||
|
||||
assert output is not None
|
||||
# Reshape the output tensor.
|
||||
return output.view(num_tokens, hidden_size)
|
||||
@@ -667,7 +667,6 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@@ -680,9 +679,6 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
|
||||
differential_flash_attention_config
|
||||
self.used_shared_kv_cache = kv_sharing_target_layer_name is not None
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
if use_irope:
|
||||
logger.warning(
|
||||
"Using irope in V0 is not supported yet, it will fall back "
|
||||
|
||||
@@ -287,7 +287,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -615,7 +615,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@@ -624,9 +623,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"FLASH_ATTN backend.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
if use_irope:
|
||||
logger.warning(
|
||||
"Using irope in V0 is not supported yet, it will fall back "
|
||||
|
||||
@@ -999,7 +999,6 @@ class FlashInferImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -181,7 +181,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@@ -189,20 +188,17 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
assert is_flashmla_supported(), \
|
||||
"FlashMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
||||
@@ -997,7 +997,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -367,7 +367,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@@ -375,17 +374,14 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -494,7 +494,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@@ -507,9 +506,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
logger.warning_once(
|
||||
"Using irope in ROCm Flash Attention is not supported yet, it "
|
||||
"will fail back to global attention for long context.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"ROCmFlashAttention does not support blocksparse attention.")
|
||||
if use_irope:
|
||||
logger.warning(
|
||||
"Using irope in V0 is not supported yet, it will fall back "
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -35,7 +35,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@@ -43,17 +42,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with xFormers and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from xformers import ops as xops
|
||||
@@ -387,7 +387,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@@ -396,9 +395,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"XFORMERS backend.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"XFormers does not support block-sparse attention.")
|
||||
if logits_soft_cap is not None:
|
||||
logger.warning_once("XFormers does not support logits soft cap. "
|
||||
"Outputs may be slightly off.")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -74,7 +74,6 @@ class Attention(nn.Module):
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
@@ -163,12 +162,11 @@ class Attention(nn.Module):
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
blocksparse_params is not None,
|
||||
use_mla=use_mla)
|
||||
impl_cls = attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **extra_impl_args)
|
||||
self.backend = backend_name_to_enum(attn_backend.get_name())
|
||||
self.dtype = dtype
|
||||
|
||||
@@ -1,433 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def blocksparse_flash_attn_varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v, # (#tokens, n_heads, head_size)
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q,
|
||||
sm_scale,
|
||||
sparse_layout,
|
||||
*,
|
||||
block_size=64,
|
||||
q_block_size=None,
|
||||
max_seqlen=None):
|
||||
# split q to blocks
|
||||
|
||||
assert isinstance(sparse_layout, (list, tuple))
|
||||
|
||||
_, n_heads, head_size = q.shape
|
||||
batch_size = cu_seqlens_k.size(0) - 1
|
||||
q_block_size = q_block_size or block_size
|
||||
|
||||
assert q.dim() == k.dim() == v.dim() == 3
|
||||
assert q.size(1) % k.size(1) == 0
|
||||
assert q.size(2) == k.size(2)
|
||||
# TODO(linxihui): allow k, v to have different head_size
|
||||
assert k.shape == v.shape
|
||||
assert cu_seqlens_k.dim() == 1
|
||||
|
||||
q_k_ratio = q.size(1) // k.size(1)
|
||||
|
||||
if cu_seqlens_q is None:
|
||||
if q.size(0) == batch_size: # decoding only
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
batch_size + 1,
|
||||
dtype=cu_seqlens_k.dtype,
|
||||
device=cu_seqlens_k.device,
|
||||
)
|
||||
elif q.size(0) == k.size(0):
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
else:
|
||||
raise ValueError("cu_seqlens_q must be specified\
|
||||
if it is a mix of prefilling and decoding.")
|
||||
else:
|
||||
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
|
||||
|
||||
# Switch to CPU to avoid too many kernel launches when iterating over them.
|
||||
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
|
||||
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
|
||||
|
||||
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
|
||||
"length of q should either be 1 (decoding) or same as k (prefilling).")
|
||||
|
||||
if max_seqlen:
|
||||
assert k_lens.max() <= max_seqlen
|
||||
|
||||
n_blocks = (q_lens + q_block_size - 1) // q_block_size
|
||||
|
||||
q_batch_ids = torch.tensor(
|
||||
[i for i, n in enumerate(n_blocks) for _ in range(n)],
|
||||
dtype=cu_seqlens_q.dtype,
|
||||
device=cu_seqlens_q.device,
|
||||
)
|
||||
q_start_sids = torch.tensor(
|
||||
[i * q_block_size for n in n_blocks for i in range(n)],
|
||||
dtype=cu_seqlens_q.dtype,
|
||||
device=cu_seqlens_q.device,
|
||||
)
|
||||
|
||||
    out = q.new_empty(q.shape)
    cu_seqlens_q = cu_seqlens_q.contiguous()
    cu_seqlens_k = cu_seqlens_k.contiguous()

    layout_crow_indices, layout_col_indices = sparse_layout
    block_d = triton.next_power_of_2(head_size)

    decoding_only = (q_lens == 1).all().item()
    grid = (len(q_start_sids), n_heads, 1)

    _fwd_kernel_batch_inference[grid](
        q,
        k,
        v,
        out,
        sm_scale,
        cu_seqlens_q[:-1],
        cu_seqlens_q[1:],
        cu_seqlens_k[:-1],
        cu_seqlens_k[1:],
        q_batch_ids,
        q_start_sids,
        0,
        *q.stride(),
        0,
        *k.stride(),
        0,
        *v.stride(),
        0,
        *out.stride(),
        layout_crow_indices,
        layout_col_indices,
        *layout_crow_indices.stride(),
        *layout_col_indices.stride(),
        q_k_ratio,
        HAS_BATCH_DIM=False,
        D_HEAD=head_size,
        BLOCK_M=q_block_size,
        BLOCK_N=block_size,
        BLOCK_D=block_d,
        BLOCK_M_LOADING=(16 if decoding_only else
                         q_block_size),  # smaller for decoding
        EVEN_D=block_d == head_size,
        num_warps=1 if decoding_only else 4,
        num_stages=3)

    return out


@triton.jit
def _fwd_kernel_inner(
    acc,
    l_i,
    m_i,
    q,
    Q,
    k_block_col_idx,
    layout_col_ptr,
    layout_col_stride_h,
    layout_col_stride_m,
    k_ptrs,
    v_ptrs,
    off_h,
    offs_m,
    offs_n,
    offs_d,
    stride_kt,
    stride_vt,
    sm_scale,
    k_seqlen,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    BLOCK_N: tl.constexpr,
    D_HEAD: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr,
):
    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
                         k_block_col_idx * layout_col_stride_m).to(tl.int32)
    start_n = k_block_id * BLOCK_N
    if LAST_K_BLOCK:
        if EVEN_D:
            k = tl.load(
                k_ptrs + start_n * stride_kt,
                mask=offs_n[None, :] + start_n < k_seqlen,
                other=0.0,
            )
        else:
            k = tl.load(
                k_ptrs + start_n * stride_kt,
                mask=(offs_n[None, :] + start_n < k_seqlen) &
                (offs_d[:, None] < D_HEAD),
                other=0.0,
            )
    else:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt)
        else:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_d[:, None] < D_HEAD,
                        other=0.0)

    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
    qk += tl.dot(q, k)
    qk *= sm_scale

    # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
    if LAST_K_BLOCK | M_LT_N:
        qk += tl.where(
            offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
            0,
            float("-inf"),
        )

    # flash-attn2
    m_ij = tl.maximum(m_i, tl.max(qk, 1))
    p = tl.math.exp2(qk - m_ij[:, None])
    l_ij = tl.sum(p, 1)
    alpha = tl.math.exp2(m_i - m_ij)
    acc = acc * alpha[:, None]
    # update m_i
    m_i = m_ij
    l_i = l_i * alpha + l_ij

    p = p.to(Q.dtype.element_ty)
    # update acc
    if LAST_K_BLOCK:
        if EVEN_D:
            v = tl.load(
                v_ptrs + start_n * stride_vt,
                mask=offs_n[:, None] + start_n < k_seqlen,
                other=0.0,
            )
        else:
            v = tl.load(
                v_ptrs + start_n * stride_vt,
                mask=(offs_n[:, None] + start_n < k_seqlen) &
                (offs_d[None, :] < D_HEAD),
                other=0.0,
            )
    else:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_d[None, :] < D_HEAD,
                        other=0.0)

    acc += tl.dot(p, v)

    return acc, l_i, m_i

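# NOTE: illustrative sketch only, not part of the original kernel. The
# "flash-attn2" block above keeps a running max (m_i), a running denominator
# (l_i) and a rescaled accumulator (acc), so the softmax never materializes
# the full score matrix. The plain-PyTorch sketch below re-derives the same
# update under simplifying assumptions (no masking, a single head, base-2
# exponents matching the kernel's tl.math.exp2); the function name and shapes
# are ours, not vLLM's.
import torch


def online_softmax_attention(q, k, v, block_n=4):
    scale = 1.44269504 / q.shape[-1]**0.5  # fold 1/log(2) into the scale
    m_i = torch.full((q.shape[0], ), float("-inf"))
    l_i = torch.zeros(q.shape[0])
    acc = torch.zeros(q.shape[0], v.shape[-1])
    for start in range(0, k.shape[0], block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        qk = (q @ kb.T) * scale
        m_ij = torch.maximum(m_i, qk.max(dim=1).values)  # new running max
        p = torch.exp2(qk - m_ij[:, None])               # like tl.math.exp2
        alpha = torch.exp2(m_i - m_ij)                   # rescale old state
        acc = acc * alpha[:, None] + p @ vb
        l_i = l_i * alpha + p.sum(dim=1)
        m_i = m_ij
    return acc / l_i[:, None]


# Tiny self-check against the naive softmax (all K/V blocks visible here).
_q, _k, _v = torch.randn(3, 8), torch.randn(16, 8), torch.randn(16, 8)
_ref = torch.softmax((_q @ _k.T) / 8**0.5, dim=-1) @ _v
assert torch.allclose(online_softmax_attention(_q, _k, _v), _ref, atol=1e-4)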
@triton.heuristics({
|
||||
"M_LT_N":
|
||||
lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
|
||||
})
|
||||
@triton.jit
|
||||
def _fwd_kernel_batch_inference(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
Out,
|
||||
sm_scale,
|
||||
q_batch_starts,
|
||||
q_batch_ends,
|
||||
k_batch_starts,
|
||||
k_batch_ends,
|
||||
q_batch_ids,
|
||||
q_start_sids,
|
||||
stride_qb,
|
||||
stride_qt,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kb,
|
||||
stride_kt,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vb,
|
||||
stride_vt,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_ob,
|
||||
stride_ot,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
layout_crow_ptr,
|
||||
layout_col_ptr,
|
||||
layout_crow_stride_h,
|
||||
layout_crow_stride_m,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
q_k_ratio,
|
||||
HAS_BATCH_DIM: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
BLOCK_M_LOADING: tl.constexpr,
|
||||
EVEN_D: tl.constexpr,
|
||||
M_LT_N: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
NOTATION:
|
||||
pid: position id
|
||||
sid: storage id
|
||||
sbid: storage block id
|
||||
pbid: position block id
|
||||
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
|
||||
|
||||
TODO(linxihui):
|
||||
Optimize grouped-attn
|
||||
"""
|
||||
off_zm = tl.program_id(0)
|
||||
off_h = tl.program_id(1)
|
||||
|
||||
off_h_for_kv = off_h // q_k_ratio
|
||||
|
||||
if HAS_BATCH_DIM:
|
||||
off_z = tl.program_id(2)
|
||||
Q += off_z * stride_qb
|
||||
K += off_z * stride_kb
|
||||
V += off_z * stride_vb
|
||||
Out += off_z * stride_ob
|
||||
start_m = off_zm
|
||||
q_start_sid = start_m * BLOCK_M # always 0 for decoding
|
||||
else:
|
||||
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
|
||||
q_start_sid = tl.load(q_start_sids + off_zm)
|
||||
start_m = q_start_sid // BLOCK_M # q_sbid
|
||||
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_D)
|
||||
|
||||
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
|
||||
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
|
||||
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
|
||||
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
|
||||
past_len = k_seqlen - q_seqlen
|
||||
|
||||
Q += q_cu_start * stride_qt + off_h * stride_qh
|
||||
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
|
||||
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
|
||||
Out += q_cu_start * stride_ot + off_h * stride_oh
|
||||
|
||||
q_pbid = (past_len + q_start_sid) // BLOCK_M
|
||||
|
||||
if EVEN_D:
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||
mask=offs_m[:, None] < q_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
|
||||
q_pbid * layout_crow_stride_m)
|
||||
|
||||
# TODO(linxihui): load at once, with any Triton version
|
||||
# that supports `tl.split`, e.g., Triton 3.0
|
||||
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
|
||||
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
|
||||
|
||||
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
|
||||
|
||||
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
|
||||
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
|
||||
|
||||
sm_scale *= (
|
||||
1.44269504 # 1/log2 as we use base2 for exponential and logarithm
|
||||
)
|
||||
|
||||
for k_block_col_idx in range(k_block_start, k_block_end - 1):
|
||||
acc, l_i, m_i = _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_col_idx,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
False,
|
||||
BLOCK_M_LOADING,
|
||||
BLOCK_N,
|
||||
D_HEAD,
|
||||
EVEN_D,
|
||||
M_LT_N,
|
||||
)
|
||||
|
||||
acc, l_i, m_i = _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_end - 1,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
True,
|
||||
BLOCK_M_LOADING,
|
||||
BLOCK_N,
|
||||
D_HEAD,
|
||||
EVEN_D,
|
||||
M_LT_N,
|
||||
)
|
||||
|
||||
# flash-attn 2
|
||||
m_i += tl.math.log2(l_i)
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
# write output
|
||||
if EVEN_D:
|
||||
tl.store(
|
||||
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
|
||||
acc,
|
||||
mask=offs_m[:, None] < q_seqlen,
|
||||
)
|
||||
else:
|
||||
tl.store(
|
||||
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
|
||||
acc,
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
)
|
||||
@ -1,239 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (dense_to_crow_col, get_head_sliding_step,
|
||||
get_sparse_attn_mask)
|
||||
|
||||
IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
|
||||
|
||||
if IS_COMPUTE_8_OR_ABOVE:
|
||||
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
|
||||
|
||||
|
||||
class LocalStridedBlockSparseAttn(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_heads,
|
||||
max_seqlen,
|
||||
local_blocks,
|
||||
vert_stride,
|
||||
block_size,
|
||||
device=None,
|
||||
dtype=None,
|
||||
homo_head=False,
|
||||
active_head_range=None,
|
||||
q_block_size=None,
|
||||
use_spda=None,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spda is None:
|
||||
use_spda = current_platform.is_rocm() or \
|
||||
current_platform.is_cpu() or not \
|
||||
IS_COMPUTE_8_OR_ABOVE
|
||||
device = device or (torch.cuda.current_device()
|
||||
if current_platform.is_cuda_alike() else "cpu")
|
||||
device = torch.device(device)
|
||||
# NOTE: vllm CPU backend support BF16 instead of FP16.
|
||||
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
|
||||
or device.type == "cpu" else torch.half)
|
||||
|
||||
self.n_heads = n_heads
|
||||
self.max_seqlen = max_seqlen
|
||||
self.local_blocks = local_blocks
|
||||
self.vert_stride = vert_stride
|
||||
self.use_spda = use_spda
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.block_size = block_size
|
||||
self.q_block_size = q_block_size
|
||||
self.homo_head = homo_head
|
||||
self.active_head_range = active_head_range
|
||||
self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
|
||||
homo_head)
|
||||
|
||||
sparse_layout, sparse_pattern, self.dense_attn_mask = (
|
||||
self.get_attn_pattern(dtype, device))
|
||||
|
||||
if q_block_size is not None and q_block_size != block_size:
|
||||
if q_block_size > block_size:
|
||||
assert q_block_size % block_size == 0
|
||||
blocks_to_merge = q_block_size // block_size
|
||||
shape = sparse_pattern.shape
|
||||
sparse_pattern = sparse_pattern.view(shape[0], -1,
|
||||
blocks_to_merge,
|
||||
shape[-1])
|
||||
sparse_pattern = sparse_pattern.sum(2)
|
||||
sparse_layout = dense_to_crow_col(sparse_pattern)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Does not support smaller q_block_size. It will be slower."
|
||||
)
|
||||
|
||||
self.sparse_layout = sparse_layout
|
||||
|
||||
def get_attn_pattern(self, dtype, device):
|
||||
sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
|
||||
self.n_heads,
|
||||
self.max_seqlen,
|
||||
self.max_seqlen,
|
||||
dtype,
|
||||
device,
|
||||
block_size=self.block_size,
|
||||
local_blocks=self.local_blocks,
|
||||
vert_stride=self.vert_stride,
|
||||
homo_head=self.homo_head,
|
||||
return_dense=self.use_spda,
|
||||
dense_mask_type="bias",
|
||||
)
|
||||
if (not self.homo_head) and (self.active_head_range is not None):
|
||||
assert isinstance(self.active_head_range, tuple)
|
||||
assert (len(self.active_head_range) == 2)
|
||||
h_start, h_end = self.active_head_range
|
||||
sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
|
||||
if self.use_spda:
|
||||
dense_attn_mask = dense_attn_mask[h_start:h_end]
|
||||
return sparse_layout, sparse_pattern, dense_attn_mask
|
||||
|
||||
    def varlen_attn(self,
                    q,
                    k,
                    v,
                    cu_seqlens_k,
                    cu_seqlens_q=None,
                    sm_scale=None):
        """
        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
            Supports grouped attention: `q[:, i*r:(i*r + r)]`
            corresponds to `k[:, i]`, where `r` is the q/k head ratio.
        cu_seqlens_k: shape=(batch_size + 1,),
            cumulative token offsets marking each sample's segment,
            e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is the key
            segment of sample i.
        cu_seqlens_q: shape=(batch_size + 1,).
            Default None: same as cu_seqlens_k for prefilling or
            [0, 1, .., batch_size] for decoding.
            The only case you need to specify it is when q is a mix of
            prefilling and decoding.
        sm_scale: softmax scale, defaults to 1/sqrt(head_size).

        return: tensor with the same shape as q.
        """
        assert (
            IS_COMPUTE_8_OR_ABOVE
        ), "Requires compute capability of 8 or above (Ampere or newer) to use \
            Triton kernel."

        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))

        return blocksparse_flash_attn_varlen_fwd(
            q,
            k,
            v,
            cu_seqlens_k,
            cu_seqlens_q,
            sm_scale,
            self.sparse_layout,
            block_size=self.block_size,
            q_block_size=self.q_block_size,
            max_seqlen=self.max_seqlen,
        )

@staticmethod
|
||||
def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
|
||||
"""
|
||||
:param x: (total_tokens, n_heads, head_size)
|
||||
:return: (batch, n_heads, length, head_size)
|
||||
"""
|
||||
x_padded = x.new_empty(
|
||||
len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
|
||||
cu_seqlens = cu_seqlens.cpu()
|
||||
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
|
||||
x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
|
||||
1).unsqueeze(1))
|
||||
return x_padded.flatten(1, 2)
|
||||
|
||||
@staticmethod
|
||||
def transpose_and_unpad(x_padded, cu_seqlens):
|
||||
"""
|
||||
:param x_padded: (batch, n_heads, length, head_size)
|
||||
:return: (total_tokens, n_heads, head_size)
|
||||
"""
|
||||
cu_seqlens = cu_seqlens.cpu()
|
||||
total_n_tokens = cu_seqlens[-1]
|
||||
x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
|
||||
x_padded.size(3))
|
||||
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
|
||||
x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
|
||||
return x
|
||||
|
||||
def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
|
||||
"""For CPU, V100 or other older GPUs.
|
||||
NOTE: torch SPDA supports nested tensor,
|
||||
but seems extremely slow. Choose to pad instead.
|
||||
"""
|
||||
assert (cu_seqlens_q is None or
|
||||
(cu_seqlens_q
|
||||
== cu_seqlens_k).all()), "Can only handle prompt with SPDA."
|
||||
assert q.size(0) == k.size(0), "can only handle prompt with SPDA."
|
||||
|
||||
assert q.size(1) % k.size(1) == 0
|
||||
q_k_ratio = q.size(1) // k.size(1)
|
||||
sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
|
||||
cu_seqlens = cu_seqlens_k.cpu()
|
||||
maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
|
||||
if (self.dense_attn_mask.dtype != q.dtype
|
||||
or self.dense_attn_mask.device != q.device):
|
||||
_, _, self.dense_attn_mask = self.get_attn_pattern(
|
||||
q.dtype, q.device)
|
||||
attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]
|
||||
|
||||
q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
|
||||
k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
|
||||
for x in [k, v])
|
||||
spda_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
|
||||
return self.transpose_and_unpad(spda_output, cu_seqlens)
|
||||
|
||||
def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
|
||||
"""Dispatch to `varlen_attn` (Ampere or newer) or
|
||||
`self.spda`(cpu, Volta, Turing or older)based on
|
||||
the type of device used and cuda compute capability.
|
||||
|
||||
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
|
||||
Support grouped attention, with `q[:, i*r:(i*r + r)]`
|
||||
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
|
||||
cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
|
||||
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
|
||||
cu_seqlens_q: shape=(batch_size + 1, ).
|
||||
Default None: same as cu_seqlens_k for prefilling or
|
||||
[0, 1, .., batch_size] for decoding.
|
||||
The only case you need to specify
|
||||
is when q is a mix of prefilling
|
||||
and decoding.
|
||||
sm_scale: softmax scale, default to 1/sqrt(head_size).
|
||||
|
||||
return: tensor of shape as q.
|
||||
"""
|
||||
assert k.dim() == 3
|
||||
if self.use_spda:
|
||||
return self.spda(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
sm_scale=sm_scale,
|
||||
)
|
||||
return self.varlen_attn(q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
sm_scale=sm_scale)
|
||||
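# NOTE: illustrative usage sketch for the module removed above; the sizes and
# the two-sample batch are assumptions, not taken from a vLLM call site, and
# the import only works on revisions that still ship this module.
import torch
from vllm.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn)

# 8 query heads sharing 2 KV heads (q/k ratio 4); two prompts of length 5 and
# 3 packed into a single token dimension of 8.
attn = LocalStridedBlockSparseAttn(n_heads=8,
                                   max_seqlen=1024,
                                   local_blocks=16,
                                   vert_stride=8,
                                   block_size=64)
q = torch.randn(8, 8, 64, dtype=attn.dtype, device=attn.device)
k = torch.randn(8, 2, 64, dtype=attn.dtype, device=attn.device)
v = torch.randn_like(k)
# cu_seqlens_k holds cumulative token offsets per sample: [0, 5, 8].
cu_seqlens_k = torch.tensor([0, 5, 8], dtype=torch.int32, device=attn.device)
out = attn(q, k, v, cu_seqlens_k)  # Triton path on Ampere+, SDPA otherwise
assert out.shape == q.shape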
@ -1,246 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Helper functions for 3D sparse pattern
|
||||
# These function are not optimized and very inefficient.
|
||||
# Avoid calling them too frequent or use a cache mechanism.
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
class csr_matrix:
|
||||
"""Simple implementation of CSR matrix conversion without scipy.
|
||||
This replaced scipy.sparse.csr_matrix() previously used."""
|
||||
|
||||
def __init__(self, input_array):
|
||||
if not isinstance(input_array, np.ndarray):
|
||||
raise ValueError("Input must be a NumPy array")
|
||||
|
||||
self.shape = input_array.shape
|
||||
rows, cols = self.shape
|
||||
data = []
|
||||
indices = []
|
||||
indptr = [0]
|
||||
|
||||
for i in range(rows):
|
||||
for j in range(cols):
|
||||
if input_array[i, j]:
|
||||
data.append(input_array[i, j])
|
||||
indices.append(j)
|
||||
indptr.append(len(indices))
|
||||
|
||||
self.data = np.array(data)
|
||||
self.indices = np.array(indices)
|
||||
self.indptr = np.array(indptr)
|
||||
|
||||
|
||||
def dense_to_crow_col(x: torch.Tensor):
    """Convert a 2D/3D torch tensor (x) to CSR crow/col indexing.
    NOTE: col_indices are padded with -1.
    """
    device = x.device
    pad = -1
    dim = x.dim()
    assert x.dim() in (2, 3)
    if x.dim() == 2:
        x = x[None]
    x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
    crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
    cols = [torch.from_numpy(xi.indices) for xi in x]
    max_cols = max(len(xi) for xi in cols)
    cols = [
        torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
        for xi in cols
    ]
    cols = torch.vstack(cols)
    if dim == 2:
        crows = crows[0]
        cols = cols[0]
    return crows.to(device), cols.to(device)

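# NOTE: small illustrative check of the CSR layout produced above; the toy
# mask and the expected values are worked out by hand, not taken from vLLM
# tests.
_toy_mask = torch.tensor([[1, 0, 0],
                          [1, 1, 0],
                          [1, 0, 1]])
_crow, _col = dense_to_crow_col(_toy_mask)
# Row i keeps the columns col[crow[i]:crow[i+1]].
assert _crow.tolist() == [0, 1, 3, 5]
assert _col.tolist() == [0, 0, 1, 0, 2]
# The -1 padding only appears for 3D input, where batch entries can have
# different numbers of nonzero blocks.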
def crow_col_to_dense(crows: torch.Tensor,
|
||||
cols: torch.Tensor,
|
||||
dtype: torch.dtype = torch.float16):
|
||||
dim = crows.dim()
|
||||
if dim == 1:
|
||||
crows = crows[None]
|
||||
cols = cols[None]
|
||||
device = crows.device
|
||||
crows, cols = crows.cpu(), cols.cpu() # faster in cpu
|
||||
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
|
||||
x = torch.zeros(shape, dtype=dtype)
|
||||
for i in range(shape[0]):
|
||||
for j in range(shape[1]):
|
||||
x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
|
||||
if dim == 1:
|
||||
x = x[0]
|
||||
return x.to(device)
|
||||
|
||||
|
||||
def dense_to_ccol_row(x: torch.Tensor):
|
||||
"""Similar, but to CSC format"""
|
||||
x = x.transpose(-2, -1)
|
||||
return dense_to_crow_col(x)
|
||||
|
||||
|
||||
def ccol_row_to_dense(ccol: torch.Tensor,
|
||||
rows: torch.Tensor,
|
||||
dtype: torch.dtype = torch.float16):
|
||||
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
|
||||
|
||||
|
||||
def _get_sparse_attn_mask_homo_head(
|
||||
q_len: int,
|
||||
max_seqlen: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
block_size: int = 128,
|
||||
local_blocks: int = 4,
|
||||
vert_stride: int = 4,
|
||||
return_dense: bool = False,
|
||||
):
|
||||
"""
|
||||
:return: a tuple of 3:
|
||||
- tuple of crow_indices, col_indices representation
|
||||
of CSR format.
|
||||
- block dense mask
|
||||
- all token dense mask (be aware that it can be
|
||||
OOM if it is too big) if `return_dense==True`,
|
||||
otherwise, None
|
||||
"""
|
||||
with torch.no_grad():
|
||||
num_blocks = triton.cdiv(max_seqlen, block_size)
|
||||
q_pos = torch.arange(num_blocks)[:, None]
|
||||
k_pos = torch.arange(num_blocks)[None]
|
||||
mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
|
||||
block_mask_dense = (((q_pos >= k_pos)
|
||||
& ((q_pos - k_pos < local_blocks)
|
||||
| mask_vert_strided)).to(device).to(dtype))
|
||||
num_blocks_q = triton.cdiv(q_len, block_size)
|
||||
block_mask_dense_output = (dense_to_crow_col(
|
||||
block_mask_dense[-num_blocks_q:].contiguous()))
|
||||
if return_dense:
|
||||
mask_dense = torch.kron(
|
||||
block_mask_dense,
|
||||
block_mask_dense.new_ones((block_size, block_size)),
|
||||
)
|
||||
causal_mask = torch.tril(torch.ones(
|
||||
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
|
||||
mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
|
||||
return (
|
||||
block_mask_dense_output,
|
||||
block_mask_dense,
|
||||
mask_dense,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
block_mask_dense_output,
|
||||
block_mask_dense,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def binary_mask_to_bias(mask_dense: torch.Tensor):
    mask_dense = 1 - mask_dense
    mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
    return mask_dense


def get_head_sliding_step(n_heads: int,
                          vert_stride: int,
                          homo_head: bool = False):
    if homo_head:
        return 0
    return max(1, int(vert_stride / n_heads))

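# NOTE: illustrative sketch of the "local + vertical stride" block pattern
# that get_sparse_attn_mask (below) builds in the homogeneous-head case; the
# block counts are made up for readability. A query block attends to a key
# block iff it is causal (q_block >= k_block) and either local
# (q_block - k_block < local_blocks) or the key block sits on the vertical
# stride ((k_block + 1) % vert_stride == 0).
_num_blocks, _local_blocks, _vert_stride = 8, 2, 4
_q_pos = torch.arange(_num_blocks)[:, None]
_k_pos = torch.arange(_num_blocks)[None]
_vert = (torch.arange(_num_blocks) + 1) % _vert_stride == 0
_block_mask = (_q_pos >= _k_pos) & ((_q_pos - _k_pos < _local_blocks) | _vert)
# Each row keeps its 2 most recent blocks plus every 4th block further back
# (columns 3 and 7 here).
print(_block_mask.int())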
@lru_cache
|
||||
def get_sparse_attn_mask(
|
||||
n_heads: int,
|
||||
q_len: int,
|
||||
max_seqlen: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
block_size: int = 64,
|
||||
local_blocks: int = 4,
|
||||
vert_stride: int = 4,
|
||||
homo_head: bool = True,
|
||||
return_dense: bool = False,
|
||||
dense_mask_type: str = "binary",
|
||||
):
|
||||
"""
|
||||
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
|
||||
or "bias" (-inf for skip token, 0 or others)
|
||||
:return: a tuple of 3:
|
||||
- tuple of crow_indices, col_indices representation
|
||||
of CSR format.
|
||||
- block dense mask
|
||||
- all token dense mask (be aware that it can be OOM if it
|
||||
is too big) if `return_dense==True`, otherwise, None
|
||||
"""
|
||||
assert dense_mask_type in ("binary", "bias")
|
||||
if homo_head:
|
||||
with torch.no_grad():
|
||||
(crow, col), block_mask_dense, mask_dense = (
|
||||
_get_sparse_attn_mask_homo_head(
|
||||
q_len,
|
||||
max_seqlen,
|
||||
dtype,
|
||||
device,
|
||||
block_size,
|
||||
local_blocks,
|
||||
vert_stride,
|
||||
return_dense,
|
||||
))
|
||||
crow = crow[None].expand(n_heads, crow.shape[0])
|
||||
col = col[None].expand(n_heads, col.shape[0])
|
||||
if return_dense:
|
||||
mask_dense = mask_dense[None].expand(n_heads,
|
||||
*mask_dense.shape)
|
||||
if dense_mask_type == "bias":
|
||||
mask_dense = binary_mask_to_bias(mask_dense)
|
||||
return (crow, col), block_mask_dense, mask_dense
|
||||
|
||||
with torch.no_grad():
|
||||
num_blocks = triton.cdiv(max_seqlen, block_size)
|
||||
q_pos = torch.arange(num_blocks)[None, :, None]
|
||||
k_pos = torch.arange(num_blocks)[None, None]
|
||||
head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
|
||||
mask_vert_strided = [
|
||||
(torch.arange(num_blocks) + h * head_sliding_step + 1) %
|
||||
vert_stride == 0 for h in range(n_heads)
|
||||
]
|
||||
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
|
||||
block_mask_dense = (((q_pos >= k_pos)
|
||||
& ((q_pos - k_pos < local_blocks)
|
||||
| mask_vert_strided)).to(device).to(dtype))
|
||||
num_blocks_q = triton.cdiv(q_len, block_size)
|
||||
block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
|
||||
if return_dense:
|
||||
mask_dense = torch.kron(
|
||||
block_mask_dense,
|
||||
block_mask_dense.new_ones((block_size, block_size)),
|
||||
)
|
||||
causal_mask = torch.tril(torch.ones(
|
||||
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
|
||||
mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
|
||||
if dense_mask_type == "bias":
|
||||
mask_dense = binary_mask_to_bias(mask_dense)
|
||||
|
||||
return (
|
||||
dense_to_crow_col(block_mask_dense_output),
|
||||
block_mask_dense,
|
||||
mask_dense,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
dense_to_crow_col(block_mask_dense_output),
|
||||
block_mask_dense,
|
||||
None,
|
||||
)
|
||||
@ -143,7 +143,6 @@ def get_attn_backend(
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
is_blocksparse: bool = False,
|
||||
use_mla: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
@ -157,7 +156,6 @@ def get_attn_backend(
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
is_attention_free=is_attention_free,
|
||||
is_blocksparse=is_blocksparse,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
@ -170,16 +168,9 @@ def _cached_get_attn_backend(
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
is_blocksparse: bool = False,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
if is_blocksparse:
|
||||
logger.info("Using BlocksparseFlashAttention backend.")
|
||||
from vllm.attention.backends.blocksparse_attn import (
|
||||
BlocksparseFlashAttentionBackend)
|
||||
return BlocksparseFlashAttentionBackend
|
||||
|
||||
# If there are no attention layers (e.g. we are running Mamba),
|
||||
# use the placeholder NO_ATTENTION
|
||||
if is_attention_free:
|
||||
|
||||
@ -1,465 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
def load_column_parallel_weight(param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
tp = get_tensor_model_parallel_world_size()
|
||||
rk = get_tensor_model_parallel_rank()
|
||||
assert param.size(0) * tp == loaded_weight.size(0)
|
||||
s = rk * param.size(0)
|
||||
e = (rk + 1) * param.size(0)
|
||||
loaded_weight = loaded_weight[s:e]
|
||||
assert param.shape == loaded_weight.shape
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class HeadMajorQKVParallelLinear(QKVParallelLinear):
|
||||
|
||||
def weight_loader(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
return load_column_parallel_weight(param, loaded_weight)
|
||||
|
||||
|
||||
class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
|
||||
|
||||
def weight_loader(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
return load_column_parallel_weight(param, loaded_weight)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def quick_gelu(x):
    return x * torch.sigmoid(1.702 * x)


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def gegelu(input, limit: Optional[float] = None):
    a_gelu, a_linear = input[..., ::2], input[..., 1::2]
    if limit is not None:
        a_gelu = torch.where(torch.isinf(a_gelu), a_gelu,
                             a_gelu.clamp(min=None, max=limit))
        a_linear = torch.where(
            torch.isinf(a_linear),
            a_linear,
            a_linear.clamp(min=-limit, max=limit),
        )
    out_gelu = quick_gelu(a_gelu)
    return out_gelu * (a_linear + 1)

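# NOTE: toy illustration of the gegelu split above (limit=None); the shapes
# are arbitrary. The up-projection output interleaves (gate, linear) pairs:
# even feature indices feed the quick-GELU gate, odd indices the linear
# branch, so the activation halves the last dimension.
_h = torch.randn(2, 4, 8)
_gate, _lin = _h[..., ::2], _h[..., 1::2]
_out = _gate * torch.sigmoid(1.702 * _gate) * (_lin + 1)
assert _out.shape == (2, 4, 4)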
class Phi3SmallMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert (self.config.hidden_act == "gegelu"
|
||||
), "Only `gegelu` is supported for the 4.7 series of models .."
|
||||
self.hidden_size = config.hidden_size
|
||||
self.gegelu_limit = config.gegelu_limit
|
||||
self.intermediate_size = config.intermediate_size
|
||||
|
||||
self.up_proj = HeadMajorColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
2 * [self.intermediate_size],
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.up_proj(x)
|
||||
x = gegelu(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Phi3SmallSelfAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.config = config
|
||||
self.sparse_block_size = config.blocksparse_block_size
|
||||
self.homo_heads = config.blocksparse_homo_head_pattern
|
||||
self.local_blocks = config.blocksparse_num_local_blocks
|
||||
self.vert_stride = config.blocksparse_vert_stride
|
||||
|
||||
assert (config.blocksparse_block_size ==
|
||||
config.blocksparse_triton_kernel_block_size)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
# Number of Query Heads
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
# Number of total Key Value Heads before tensor parallel
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_q_per_kv = self.num_heads // self.num_key_value_heads
|
||||
if self.tp_size > 1:
|
||||
assert self.num_key_value_heads % self.tp_size == 0
|
||||
self.num_kv_heads_per_partition = max(
|
||||
1, self.num_key_value_heads // self.tp_size)
|
||||
self.num_heads_per_partition = self.num_heads // self.tp_size
|
||||
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_embedding_base = config.rope_embedding_base
|
||||
self.rope_position_scale = config.rope_position_scale
|
||||
self.is_causal = True
|
||||
|
||||
norm_factor = None
|
||||
if config.mup_use_scaling:
|
||||
norm_factor = self.head_dim / config.mup_attn_multiplier
|
||||
else:
|
||||
norm_factor = math.sqrt(self.head_dim)
|
||||
self.scale = 1 / norm_factor
|
||||
|
||||
self.query_key_value = HeadMajorQKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
self.num_key_value_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.dense = RowParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
|
||||
if getattr(self.config, "rope_scaling", None) is not None:
|
||||
rope_scaling = self.config.rope_scaling
|
||||
for key in rope_scaling:
|
||||
if isinstance(rope_scaling[key], list):
|
||||
rope_scaling[key] = tuple(rope_scaling[key])
|
||||
|
||||
if "factor" not in rope_scaling:
|
||||
rope_scaling["factor"] = self.rope_position_scale
|
||||
else:
|
||||
rope_scaling = {
|
||||
"rope_type": "linear",
|
||||
"factor": self.rope_position_scale,
|
||||
}
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_embedding_base,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
|
||||
# blocksparse params
|
||||
self.blocksparse_block_size = config.blocksparse_block_size
|
||||
self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
|
||||
self.blocksparse_vert_stride = config.blocksparse_vert_stride
|
||||
|
||||
use_dense_attn = (getattr(self.config,
|
||||
"dense_attention_every_n_layers", None)
|
||||
and (self.layer_idx + 1) %
|
||||
self.config.dense_attention_every_n_layers == 0)
|
||||
|
||||
bs_params = None
|
||||
if not use_dense_attn:
|
||||
bs_params = {
|
||||
'max_seqlen': self.max_position_embeddings,
|
||||
'num_heads': self.num_heads_per_partition,
|
||||
"num_kv_heads": self.num_kv_heads_per_partition,
|
||||
"block_size": self.sparse_block_size,
|
||||
"local_blocks": self.local_blocks,
|
||||
"vert_stride": self.vert_stride,
|
||||
"homo_head": self.homo_heads
|
||||
}
|
||||
|
||||
self.attn = Attention(self.num_heads_per_partition,
|
||||
self.head_dim,
|
||||
self.scale,
|
||||
num_kv_heads=self.num_kv_heads_per_partition,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
blocksparse_params=bs_params,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[tuple[torch.Tensor]]]:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
|
||||
qkv = qkv.view(qkv.shape[:-1] +
|
||||
(-1, (self.num_q_per_kv + 2), self.head_dim))
|
||||
q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)
|
||||
|
||||
# NOTE: this is required by RotaryEmbed, which indeed does not have to
|
||||
# TODO: allow 3D QK for rotary forward
|
||||
q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
|
||||
k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partition)
|
||||
v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partition)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.dense(attn_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Phi3SmallDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Phi3SmallSelfAttention(config,
|
||||
layer_idx,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.mlp = Phi3SmallMLP(config, quant_config)
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Phi3SmallModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.mup_embedding_multiplier = config.mup_embedding_multiplier
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Phi3SmallDecoderLayer(config,
|
||||
int(prefix.split('.')[-1]),
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||
config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: Optional[torch.LongTensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
if (self.mup_embedding_multiplier is not None
|
||||
and self.mup_embedding_multiplier > 0.0):
|
||||
hidden_states = hidden_states * self.mup_embedding_multiplier
|
||||
else:
|
||||
assert intermediate_tensors
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Phi3SmallForCausalLM(nn.Module, SupportsPP):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_suffix={"rotary_emb.inv_freq": None})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Phi3SmallModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.vocab_size = config.vocab_size
|
||||
self.mup_width_multiplier = config.mup_width_multiplier
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
# tokens in tiktoken but not used
|
||||
if hasattr(config, 'dummy_token_indices'):
|
||||
device = self.lm_head.weight.device
|
||||
self.register_buffer('dummy_token_indices',
|
||||
torch.LongTensor(
|
||||
config.dummy_token_indices).to(device),
|
||||
persistent=False)
|
||||
else:
|
||||
self.dummy_token_indices = None
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.lm_head = value
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
if self.dummy_token_indices is not None and logits is not None:
|
||||
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
|
||||
logits = logits / self.mup_width_multiplier
|
||||
return logits
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: Optional[torch.LongTensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
output_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
output_hidden_states = output_hidden_states
|
||||
return output_hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head.weight"]
|
||||
if self.config.tie_word_embeddings else None))
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
@ -110,7 +110,6 @@ _TEXT_GENERATION_MODELS = {
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),

@ -57,7 +57,6 @@ class _Backend(enum.Enum):
    PALLAS = enum.auto()
    PALLAS_VLLM_V1 = enum.auto()
    IPEX = enum.auto()
    BLOCK_SPARSE_FLASH_ATTN = enum.auto()
    DUAL_CHUNK_FLASH_ATTN = enum.auto()
    DIFFERENTIAL_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()

@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -443,7 +443,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@ -451,9 +450,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"Torch SPDA does not support block-sparse attention.")
|
||||
if logits_soft_cap is not None:
|
||||
logger.warning_once("Torch SPDA does not support logits soft cap. "
|
||||
"Outputs may be slightly off.")
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -349,15 +349,11 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
||||
@ -490,7 +490,6 @@ class FlashInferImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
"""Attention layer with FlashAttention."""
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
|
||||
@ -342,15 +342,10 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
# TODO we should support this :think
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
|
||||
@ -190,7 +190,7 @@ return curr_o @ W_O
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -754,7 +754,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -74,7 +74,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@ -82,17 +81,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -119,7 +119,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@ -127,20 +126,17 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
assert is_flashmla_supported(), \
|
||||
"FlashMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -167,7 +167,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@ -175,20 +174,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
assert (num_heads == 16 or num_heads == 128), (
|
||||
f"Aiter MLA only supports 16 or 128 number of heads.\n"
|
||||
f"Provided {num_heads} number of heads.\n"
|
||||
"Try adjusting tensor_parallel_size value.")
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -42,7 +42,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
@ -50,17 +49,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_xla.core.xla_builder as xb
|
||||
@ -132,7 +132,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
@ -142,9 +141,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
logger.warning_once(
|
||||
"Using irope in Pallas is not supported yet, it will fall back "
|
||||
"to global attention for long context.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError("Paged attention Pallas kernel does "
|
||||
"not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@ -158,8 +154,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError("FP8 KV cache dtype is not supported.")
|
||||
if blocksparse_params is not None:
|
||||
raise NotImplementedError("Blocksparse is not supported.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with AiterFlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -334,15 +334,11 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"AiterFlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -205,15 +205,11 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"TritonAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
|
||||