register custom op for flash attn and use from torch.ops (#7536)
parent 50b8d08dbd
commit 54bd9a03c4
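For context on the pattern this diff applies: each vllm_flash_attn kernel is wrapped in a torch.library.custom_op definition, given a fake (meta) implementation via register_fake, and call sites are switched to torch.ops.vllm.* so that torch.compile/Dynamo records a single opaque op instead of tracing into (and breaking on) the extension call. Below is a minimal, self-contained sketch of that pattern, using a hypothetical mylib::scaled_matmul op and assuming PyTorch >= 2.4, where torch.library.custom_op is available; it is not vLLM code.

import torch


@torch.library.custom_op("mylib::scaled_matmul", mutates_args=())
def scaled_matmul(a: torch.Tensor, b: torch.Tensor, factor: float = 1.0) -> torch.Tensor:
    # Eager implementation; its body stays opaque to torch.compile.
    return (a @ b) * factor


@scaled_matmul.register_fake
def _(a: torch.Tensor, b: torch.Tensor, factor: float = 1.0) -> torch.Tensor:
    # Fake/meta implementation: produce an output with the right shape and
    # dtype without running the real kernel; used while tracing.
    return a.new_empty((a.shape[0], b.shape[1]))


# Call sites go through torch.ops, mirroring the switch to
# torch.ops.vllm.flash_attn_varlen_func / flash_attn_with_kvcache in the diff.
x, y = torch.randn(4, 8), torch.randn(8, 16)
out = torch.ops.mylib.scaled_matmul(x, y, factor=0.5)
print(out.shape)  # torch.Size([4, 16])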
.buildkite/test-pipeline.yaml
@@ -163,6 +163,13 @@ steps:
   - pytest -v -s models/test_oot_registration.py # it needs a clean process
   - pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py

+- label: torch compile integration test
+  source_file_dependencies:
+  - vllm/
+  commands:
+  - pytest -v -s ./compile/test_full_graph.py
+
+
 - label: Vision Language Models Test # 42min
   mirror_hardwares: [amd]
   source_file_dependencies:

tests/compile/test_full_graph.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+import os
+
+import pytest
+
+
+@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
+def test_full_graph(model):
+    # make sure these models can be captured in full graph mode
+    os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
+
+    from vllm import LLM, SamplingParams
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0)
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B")
+    llm.generate(prompts, sampling_params)
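The test drives capture through the VLLM_TEST_DYNAMO_GRAPH_CAPTURE environment variable; how vLLM consumes it is internal to its model runner and not shown in this diff. As a hedged sketch only (a hypothetical helper, not vLLM's implementation), the idea is an env-gated torch.compile with fullgraph=True, which raises on any graph break instead of silently falling back to eager:

import os

import torch


def maybe_capture_full_graph(fn):
    # Hypothetical gate: when the test sets VLLM_TEST_DYNAMO_GRAPH_CAPTURE=1,
    # compile with fullgraph=True so an untraceable call (e.g. a raw extension
    # function that is not a registered custom op) fails loudly.
    if os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0") == "1":
        return torch.compile(fn, fullgraph=True)
    return fn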
tests/kernels/test_flash_attn.py
@@ -2,13 +2,16 @@ from typing import List, Optional, Tuple

 import pytest
 import torch
-from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

-NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
+import vllm.attention.backends.flash_attn  # noqa: F401
+
+NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
 DTYPES = [torch.float16, torch.bfloat16]
-NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
+# one value large enough to test overflow in index calculation.
+# one value small enough to test the schema op check
+NUM_BLOCKS = [32768, 2048]


 def ref_paged_attn(
@@ -72,6 +75,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     kv_lens: List[int],
@@ -80,6 +84,7 @@ def test_flash_attn_with_paged_kv(
     dtype: torch.dtype,
     block_size: int,
     soft_cap: Optional[float],
+    num_blocks: int,
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -91,7 +96,7 @@ def test_flash_attn_with_paged_kv(
     scale = head_size**-0.5

     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
-    key_cache = torch.randn(NUM_BLOCKS,
+    key_cache = torch.randn(num_blocks,
                             block_size,
                             num_kv_heads,
                             head_size,
@@ -101,14 +106,14 @@ def test_flash_attn_with_paged_kv(

     max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
     block_tables = torch.randint(0,
-                                 NUM_BLOCKS,
+                                 num_blocks,
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)

-    output = flash_attn_with_kvcache(
-        q=query.unsqueeze(1),
-        k_cache=key_cache,
-        v_cache=value_cache,
+    output = torch.ops.vllm.flash_attn_with_kvcache(
+        decode_query=query.unsqueeze(1),
+        key_cache=key_cache,
+        value_cache=value_cache,
         softmax_scale=scale,
         causal=True,
         block_table=block_tables,
@@ -116,6 +121,25 @@ def test_flash_attn_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
     ).squeeze(1)

+    if num_blocks <= 2048:
+        test_utils = ["test_faketensor", "test_schema"]
+    else:
+        test_utils = ["test_faketensor"]
+
+    torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache,
+                          args=tuple(),
+                          kwargs=dict(
+                              decode_query=query.unsqueeze(1),
+                              key_cache=key_cache,
+                              value_cache=value_cache,
+                              softmax_scale=scale,
+                              causal=True,
+                              block_table=block_tables,
+                              cache_seqlens=kv_lens_tensor,
+                              softcap=soft_cap if soft_cap is not None else 0,
+                          ),
+                          test_utils=test_utils)
+
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,
@@ -137,6 +161,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("sliding_window", [None])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @torch.inference_mode()
 def test_varlen_with_paged_kv(
     seq_lens: List[Tuple[int, int]],
@@ -146,6 +171,7 @@ def test_varlen_with_paged_kv(
     dtype: torch.dtype,
     block_size: int,
     soft_cap: Optional[float],
+    num_blocks: int,
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -166,7 +192,7 @@ def test_varlen_with_paged_kv(
                         num_query_heads,
                         head_size,
                         dtype=dtype)
-    key_cache = torch.randn(NUM_BLOCKS,
+    key_cache = torch.randn(num_blocks,
                             block_size,
                             num_kv_heads,
                             head_size,
@@ -181,11 +207,11 @@ def test_varlen_with_paged_kv(

     max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
     block_tables = torch.randint(0,
-                                 NUM_BLOCKS,
+                                 num_blocks,
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)

-    output = flash_attn_varlen_func(
+    output = torch.ops.vllm.flash_attn_varlen_func(
         q=query,
         k=key_cache,
         v=value_cache,
@@ -200,6 +226,29 @@ def test_varlen_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
     )

+    if num_blocks <= 2048:
+        test_utils = ["test_faketensor", "test_schema"]
+    else:
+        test_utils = ["test_faketensor"]
+
+    torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func,
+                          args=tuple(),
+                          kwargs=dict(
+                              q=query,
+                              k=key_cache,
+                              v=value_cache,
+                              cu_seqlens_q=cu_query_lens,
+                              cu_seqlens_k=cu_kv_lens,
+                              max_seqlen_q=max_query_len,
+                              max_seqlen_k=max_kv_len,
+                              softmax_scale=scale,
+                              causal=True,
+                              window_size=window_size,
+                              block_table=block_tables,
+                              softcap=soft_cap if soft_cap is not None else 0,
+                          ),
+                          test_utils=test_utils)
+
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,

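Both tests validate the new registrations with torch.library.opcheck. Following the NUM_BLOCKS comment at the top of the file, the "test_schema" utility is only requested for the smaller 2048-block inputs while "test_faketensor" runs for both sizes, presumably to keep the schema check's extra overhead manageable with the 32768-block caches. A self-contained sketch of the same call shape, using a hypothetical demo::add_bias op rather than the vLLM ops:

import torch


@torch.library.custom_op("demo::add_bias", mutates_args=())
def add_bias(x: torch.Tensor, bias: float) -> torch.Tensor:
    return x + bias


@add_bias.register_fake
def _(x: torch.Tensor, bias: float) -> torch.Tensor:
    return torch.empty_like(x)


# opcheck takes the op packet, its args/kwargs, and the utilities to run,
# matching the shape of the calls in the tests above.
torch.library.opcheck(
    torch.ops.demo.add_bias,
    args=(torch.randn(4, 4), 1.0),
    kwargs={},
    test_utils=["test_schema", "test_faketensor"],
)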
vllm/attention/backends/flash_attn.py
@@ -3,7 +3,6 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch
-from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -18,6 +17,108 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder

+from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
+from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
+
+
+@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
+def flash_attn_varlen_func(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    window_size: Optional[List[int]] = None,
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    # custom op does not support tuple input
+    real_window_size: Tuple[int, int]
+    if window_size is None:
+        real_window_size = (-1, -1)
+    else:
+        assert len(window_size) == 2
+        real_window_size = (window_size[0], window_size[1])
+    return _flash_attn_varlen_func(
+        q=q,
+        k=k,
+        v=v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        window_size=real_window_size,
+        softcap=softcap,
+        alibi_slopes=alibi_slopes,
+        block_table=block_table,
+    )
+
+
+@flash_attn_varlen_func.register_fake  # type: ignore
+def _(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    window_size: Optional[List[int]] = None,
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
+def flash_attn_with_kvcache(
+    decode_query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    cache_seqlens: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    softcap: float = 0.0,
+) -> torch.Tensor:
+    return _flash_attn_with_kvcache(
+        decode_query,
+        key_cache,
+        value_cache,
+        cache_seqlens=cache_seqlens,
+        block_table=block_table,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        alibi_slopes=alibi_slopes,
+        softcap=softcap,
+    )
+
+
+@flash_attn_with_kvcache.register_fake  # type: ignore
+def _(
+    decode_query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    cache_seqlens: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    softcap: float = 0.0,
+) -> torch.Tensor:
+    return torch.empty_like(decode_query)
+

 class FlashAttentionBackend(AttentionBackend):

@@ -517,7 +618,7 @@ class FlashAttentionImpl(AttentionImpl):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = flash_attn_varlen_func(
+                out = torch.ops.vllm.flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -537,34 +638,36 @@ class FlashAttentionImpl(AttentionImpl):
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:num_prefill_tokens] = flash_attn_varlen_func(
-                    q=query,
-                    k=key_cache,
-                    v=value_cache,
-                    cu_seqlens_q=prefill_meta.query_start_loc,
-                    max_seqlen_q=prefill_meta.max_query_len,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_k=max_seq_len,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    alibi_slopes=self.alibi_slopes,
-                    block_table=prefill_meta.block_tables,
-                    softcap=self.logits_soft_cap,
-                )
+                output[:
+                       num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func(  # noqa
+                           q=query,
+                           k=key_cache,
+                           v=value_cache,
+                           cu_seqlens_q=prefill_meta.query_start_loc,
+                           max_seqlen_q=prefill_meta.max_query_len,
+                           cu_seqlens_k=prefill_meta.seq_start_loc,
+                           max_seqlen_k=max_seq_len,
+                           softmax_scale=self.scale,
+                           causal=True,
+                           alibi_slopes=self.alibi_slopes,
+                           block_table=prefill_meta.block_tables,
+                           softcap=self.logits_soft_cap,
+                       )

         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = flash_attn_with_kvcache(
-                decode_query.unsqueeze(1),
-                key_cache,
-                value_cache,
-                block_table=decode_meta.block_tables,
-                cache_seqlens=decode_meta.seq_lens_tensor,
-                softmax_scale=self.scale,
-                causal=True,
-                alibi_slopes=self.alibi_slopes,
-                softcap=self.logits_soft_cap,
-            ).squeeze(1)
+            output[
+                num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
+                    decode_query.unsqueeze(1),
+                    key_cache,
+                    value_cache,
+                    block_table=decode_meta.block_tables,
+                    cache_seqlens=decode_meta.seq_lens_tensor,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                    softcap=self.logits_soft_cap,
+                ).squeeze(1)

         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)

vllm/attention/backends/flashinfer.py
@@ -4,9 +4,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type

 try:
     from flashinfer import BatchDecodeWithPagedKVCacheWrapper
     from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
-    from vllm_flash_attn import flash_attn_varlen_func
+
+    import vllm.attention.backends.flash_attn  # noqa
 except ImportError:
-    flash_attn_varlen_func = None
     BatchDecodeWithPagedKVCacheWrapper = None
     BatchPrefillWithPagedKVCacheWrapper = None

@@ -520,7 +520,7 @@ class FlashInferImpl(AttentionImpl):
             # This happens when vllm runs the profiling to
             # determine the number of blocks.
             if kv_cache is None:
-                output = flash_attn_varlen_func(
+                output = torch.ops.vllm.flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,