[NVIDIA] Support Flashinfer TRT-LLM Prefill Attention Kernel (#22095)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
parent
4771df7b2b
commit
83156c7b89
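Usage note (not part of the diff): the TRT-LLM attention path added here is gated by environment variables that appear later in this change. A minimal sketch of how a deployment might opt in, assuming a CUDA build of vLLM with FlashInfer installed; the model name below is only a placeholder:

# Illustrative sketch only, based on the environment variables touched in this
# diff; it is not taken from the commit itself.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # select the FlashInfer backend
os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "1"        # renamed from VLLM_USE_TRTLLM_DECODE_ATTENTION below
os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"           # the TRT-LLM kernels assert an HND KV-cache layout

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model name
print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))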
@@ -664,7 +664,7 @@ steps:
   # Attention
   # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
   - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
-  - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
+  - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
   - pytest -v -s tests/kernels/test_cutlass_mla_decode.py
   # Quantization
   - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'

@@ -41,7 +41,6 @@ def benchmark_decode(
     device = "cuda"
     torch.manual_seed(0)

-    # Currently only HEAD_GRP_SIZE == 8 is supported
     HEAD_GRP_SIZE = 8
     MAX_SEQ_LEN = max_seq_len

benchmarks/kernels/benchmark_trtllm_prefill_attention.py (new file, +250 lines)
@@ -0,0 +1,250 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
import random
from datetime import datetime

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@torch.no_grad()
def benchmark_prefill(
    num_seqs,
    max_seq_len,
    page_size=16,
    dtype=torch.bfloat16,
    kv_layout="HND",
    num_kv_heads=8,
    kv_cache_dtype="auto",
    head_dim=128,
    warmup=10,
    trials=20,
):
    torch.set_default_device("cuda")
    torch.manual_seed(0)

    HEAD_GRP_SIZE = 8
    MAX_SEQ_LEN = max_seq_len

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / page_size)

    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8)

    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
    sm_scale = float(1.0 / (head_dim**0.5))

    q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    q_lens[-1] = MAX_SEQ_LEN
    max_q_len = max(q_lens)
    q_indptr = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(
                torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32
            ),
        ]
    )
    q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype)

    kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)]
    kv_lens[-1] = MAX_SEQ_LEN

    seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)]
    max_seq_len = max(seq_lens)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
    kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype)
    k_scale = v_scale = 1.0

    if kv_cache_dtype.startswith("fp8"):
        kv_cache, _ = to_float8(kv_cache)

    output_trtllm = torch.empty(q.shape, dtype=dtype)

    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + page_size - 1) // page_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % page_size
        if kv_last_page_len == 0:
            kv_last_page_len = page_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    output_baseline = torch.empty(q.shape, dtype=dtype)

    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        causal=True,
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=kv_cache.dtype,
    )

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for i in range(warmup):
            fn()
        for i in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    def baseline_prefill():
        return wrapper.run(
            q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline
        )

    def trt_prefill():
        return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
            query=q,
            kv_cache=kv_cache,
            workspace_buffer=workspace_buffer,
            block_tables=block_tables,
            seq_lens=seq_lens_tensor,
            max_q_len=max_q_len,
            max_kv_len=max_seq_len,
            bmm1_scale=k_scale * sm_scale,
            bmm2_scale=v_scale,
            batch_size=num_seqs,
            cum_seq_lens_q=q_indptr,
            cum_seq_lens_kv=kv_indptr,
            out=output_trtllm,
        )

    trt_mean, trt_std = time_fn(trt_prefill)
    baseline_mean, baseline_std = time_fn(baseline_prefill)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trt_mean) / baseline_mean

    print(
        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}"
        f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}"
    )

    # Return results for CSV writing
    return {
        "num_seqs": num_seqs,
        "trt_mean": trt_mean,
        "trt_std": trt_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(dtype),
        "kv_cache_dtype": kv_cache_dtype,
        "page_size": page_size,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "num_seqs",
        "trt_mean",
        "trt_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "page_size",
        "num_kv_heads",
        "head_dim",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    print(
        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
        "output_dtype: bfloat16"
    )
    print(
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
        "baseline_std\tspeedup_percent"
    )
    for max_seq_len in max_seq_lens:
        for bs in num_seqs:
            result = benchmark_prefill(
                bs,
                max_seq_len,
                dtype=torch.bfloat16,
                kv_cache_dtype="auto",
            )
            all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
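A hypothetical way to exercise a single point of the benchmark above without the full __main__ sweep; the bare import assumes the script's directory is on PYTHONPATH, and both names come from the file above:

# Sketch: run one benchmark configuration of the new script directly.
import torch

from benchmark_trtllm_prefill_attention import benchmark_prefill, write_results_to_csv

result = benchmark_prefill(num_seqs=8, max_seq_len=4096,
                           dtype=torch.bfloat16, kv_cache_dtype="auto")
write_results_to_csv([result])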
tests/kernels/attention/test_flashinfer_trtllm_attention.py (new file, +293 lines)
@@ -0,0 +1,293 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import flashinfer
import pytest
import torch

from vllm.platforms import current_platform

if not current_platform.is_device_capability(100):
    pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
                allow_module_level=True)

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)

MAX_Q_LEN = 1024
MAX_KV_LEN = 4096
BATCH_SIZES = [4, 12]
NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16, 32]
KV_LAYOUTS = ["HND"]
DTYPES = [torch.float16, torch.bfloat16]
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 50.0]


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
    batch_size: int,
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
    kv_layout: str,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[torch.dtype],
    soft_cap: Optional[float],
) -> None:
    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

    torch.set_default_device("cuda")
    current_platform.seed_everything(0)

    kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
    kv_lens[-1] = MAX_KV_LEN
    max_kv_len = torch.max(kv_lens).item()
    num_seqs = len(kv_lens)

    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0

    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)

    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
    kv_scale = 1.0
    if kv_cache_dtype is current_platform.fp8_dtype():
        key_value_cache, kv_scale = to_float8(key_value_cache,
                                              current_platform.fp8_dtype())

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    k_scale = v_scale = kv_scale
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_query_heads // num_kv_heads) > 4))
    wrapper.plan(kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 "NONE",
                 sm_scale=scale,
                 q_data_type=dtype,
                 kv_data_type=kv_cache_dtype,
                 logits_soft_cap=soft_cap)

    output = torch.empty(query.shape, dtype=dtype)
    wrapper.run(query,
                key_value_cache,
                k_scale=k_scale,
                v_scale=v_scale,
                out=output)

    # TRTLLM Decode
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
    output_trtllm = torch.empty(query.shape, dtype=dtype)
    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query=query.contiguous(),
        kv_cache=key_value_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=kv_lens_tensor,
        max_seq_len=max_kv_len,
        bmm1_scale=k_scale * scale,
        bmm2_scale=v_scale,
        out=output_trtllm,
    )

    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - output_trtllm))}"


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
    batch_size: int,
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
    kv_layout: str,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[torch.dtype],
    soft_cap: Optional[float],
) -> None:
    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
    if dtype != kv_cache_dtype:
        pytest.skip(f"Not supported dtype({dtype}) with "
                    f"kv_cache_dtype({kv_cache_dtype})")

    torch.set_default_device("cuda")
    current_platform.seed_everything(0)

    q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32)
    q_lens[-1] = MAX_Q_LEN
    max_q_len = torch.max(q_lens).item()
    q_indptr = torch.cat([
        torch.tensor([0], dtype=torch.int32),
        torch.cumsum(q_lens, dim=0, dtype=torch.int32),
    ])

    kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
    kv_lens[-1] = MAX_KV_LEN

    seq_lens = kv_lens + q_lens
    max_seq_len = torch.max(seq_lens).item()
    num_seqs = len(seq_lens)

    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0

    scale = head_size**-0.5

    query = torch.randn(torch.sum(q_lens).item(),
                        num_query_heads,
                        head_size,
                        dtype=dtype)

    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
    kv_scale = 1.0
    if kv_cache_dtype is current_platform.fp8_dtype():
        key_value_cache, kv_scale = to_float8(key_value_cache,
                                              current_platform.fp8_dtype())

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    k_scale = v_scale = kv_scale
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout)
    wrapper.plan(q_indptr,
                 kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 causal=True,
                 sm_scale=scale,
                 q_data_type=dtype,
                 kv_data_type=kv_cache_dtype,
                 logits_soft_cap=soft_cap)

    output = torch.empty(query.shape, dtype=dtype)
    wrapper.run(query,
                key_value_cache,
                k_scale=k_scale,
                v_scale=v_scale,
                out=output)

    # TRTLLM Prefill
    output_trtllm = torch.empty(query.shape, dtype=dtype)
    flashinfer.prefill.trtllm_batch_context_with_kv_cache(
        query=query.contiguous(),
        kv_cache=key_value_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=seq_lens,
        max_q_len=max_q_len,
        max_kv_len=max_seq_len,
        bmm1_scale=k_scale * scale,
        bmm2_scale=v_scale,
        batch_size=num_seqs,
        cum_seq_lens_q=q_indptr,
        cum_seq_lens_kv=kv_indptr,
        out=output_trtllm,
    )

    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - output_trtllm))}"
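The CI entry at the top of this diff runs the new test module directly with pytest; an equivalent invocation from Python, shown only as a convenience sketch:

# Sketch: same invocation as the updated CI step, driven from Python.
import pytest

pytest.main(["-v", "-s",
             "tests/kernels/attention/test_flashinfer_trtllm_attention.py"])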
(deleted file, -138 lines)
@@ -1,138 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import flashinfer
import pytest
import torch

from vllm.platforms import current_platform

if not current_platform.is_device_capability(100):
    pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
                allow_module_level=True)

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)

NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 30.0, 50.0]


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", ["HND"])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
    kv_lens: list[int],
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    kv_layout: str,
) -> None:
    torch.set_default_device("cuda")
    current_platform.seed_everything(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]

    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    k_scale = v_scale = 1.0
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.\
        BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout,
                                           use_tensor_cores=(
                                               (num_query_heads//num_kv_heads) > 4)
                                           )
    wrapper.plan(kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 "NONE",
                 q_data_type=dtype,
                 kv_data_type=dtype,
                 logits_soft_cap=soft_cap)

    output = torch.empty(query.shape, dtype=dtype)
    wrapper.run(query, key_value_cache, scale, out=output)

    # TRTLLM Decode
    max_kv_len = max(kv_lens)
    kv_lens_tensor = torch.tensor(kv_lens,
                                  dtype=torch.int,
                                  device=query.device)
    output_trtllm = torch.empty(query.shape, dtype=dtype)
    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query.contiguous(),
        key_value_cache,
        workspace_buffer,
        block_tables,
        kv_lens_tensor,
        max_kv_len,
        bmm1_scale=k_scale * scale,
        bmm2_scale=v_scale,
        out=output_trtllm,
    )

    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - output_trtllm))}"
@@ -46,7 +46,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
-from vllm.utils.flashinfer import use_trtllm_decode_attention
+from vllm.utils.flashinfer import use_trtllm_attention

 logger = init_logger(__name__)

@@ -1114,7 +1114,7 @@ class FlashInferImpl(AttentionImpl):
             assert decode_meta.decode_wrapper._sm_scale == softmax_scale
             # TODO: @pavanimajety Remove this once the switch happens
             # inside flashinfer.
-            if not use_trtllm_decode_attention(
+            if not use_trtllm_attention(
                     num_decode_tokens, attn_metadata.max_decode_seq_len,
                     kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):

@@ -1027,9 +1027,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),

-    # If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
-    "VLLM_USE_TRTLLM_DECODE_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
+    # If set to 1, use the TRTLLM Attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_ATTENTION":
+    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),

     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.

@@ -144,7 +144,7 @@ def has_nvidia_artifactory() -> bool:
         return False


-def use_trtllm_decode_attention(
+def use_trtllm_attention(
     num_tokens: int,
     max_seq_len: int,
     kv_cache_dtype: str,
@@ -159,29 +159,26 @@ def use_trtllm_decode_attention(

     # Check if the dimensions are supported by TRTLLM decode attention
     if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
-            or num_qo_heads // num_kv_heads > 8
             or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
         return False

-    env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
+    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
     if env_value is not None:
-        logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
-                         env_value)
+        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
         # Environment variable is set - respect it
         # Making the conditional check for zero because
         # the path is automatically enabled if the batch size condition
         # is satisfied.
         no_use_trtllm = (env_value == "0")
         if not no_use_trtllm:
-            logger.info_once("Using TRTLLM decode attention.")
+            logger.info_once("Using TRTLLM attention.")
         return not no_use_trtllm
     else:
         # Environment variable not set - use auto-detection
         use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
                       and kv_cache_dtype == "auto")
         if use_trtllm:
-            logger.warning_once(
-                "Using TRTLLM decode attention (auto-detected).")
+            logger.warning_once("Using TRTLLM attention (auto-detected).")
         return use_trtllm

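Read on its own, the gating in use_trtllm_attention above reduces to a shape check plus an environment override. The following standalone paraphrase (illustration only, not the vLLM implementation; logging omitted) captures the decision:

# Standalone paraphrase of the gating above, for illustration only.
from typing import Optional


def would_use_trtllm(num_tokens: int,
                     max_seq_len: int,
                     kv_cache_dtype: str,
                     num_qo_heads: Optional[int],
                     num_kv_heads: Optional[int],
                     attn_head_size: Optional[int],
                     env_value: Optional[str]) -> bool:
    # Dimensions the TRT-LLM kernels accept: head size 128 and a query-head
    # count that is a multiple of the KV-head count.
    if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
            or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
        return False
    if env_value is not None:
        # VLLM_USE_TRTLLM_ATTENTION: "0" disables, anything else enables.
        return env_value != "0"
    # Auto-detection: small batches, bounded sequence length, unquantized KV cache.
    return (num_tokens <= 256 and max_seq_len < 131072
            and kv_cache_dtype == "auto")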
@@ -195,5 +192,5 @@ __all__ = [
     "has_flashinfer_moe",
     "has_flashinfer_cutlass_fused_moe",
     "has_nvidia_artifactory",
-    "use_trtllm_decode_attention",
+    "use_trtllm_attention",
 ]

@@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import (_get_range_buf, get_seq_lens,
                                trtllm_batch_decode_with_kv_cache)
+from flashinfer.prefill import trtllm_batch_context_with_kv_cache

 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import use_trtllm_decode_attention
+from vllm.utils.flashinfer import use_trtllm_attention
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -149,9 +150,12 @@ class FlashInferMetadata:
     slot_mapping: torch.Tensor

     # For flashinfer trtllm batch decode
+    max_q_len: int
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
+    prefill_use_trtllm: bool
+    decode_use_trtllm: bool

     # For handling prefill decode split
     num_decodes: int
@@ -170,6 +174,9 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

+    qo_indptr_gpu: Optional[torch.Tensor] = None
+    paged_kv_indptr_gpu: Optional[torch.Tensor] = None
+
     def __post_init__(self):
         if self.head_dim is not None:
             FlashInferBackend.validate_head_size(self.head_dim)
@@ -305,8 +312,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 2, self._get_workspace_buffer(), get_kv_cache_layout())
         return self._cascade_wrapper

-    def _plan(self, num_prefills: int, num_decodes: int,
-              attn_metadata: FlashInferMetadata):
+    def _plan(self, attn_metadata: FlashInferMetadata):
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
@@ -341,6 +347,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
+        num_prefills = attn_metadata.num_prefills
+        num_decodes = attn_metadata.num_decodes
         if num_prefills > 0:
             # Decodes are first so prefills start after the last decode
             prefill_start = num_decodes
@@ -356,11 +364,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             # to be relative to the start of the prefill queries.
             qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
                 prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
+            paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[
+                prefill_start:]
+            if not attn_metadata.prefill_use_trtllm:
                 attn_metadata.prefill_wrapper.plan(
                     qo_indptr_cpu,
-                    attn_metadata.paged_kv_indptr_cpu[prefill_start:],
+                    paged_kv_indptr_cpu,
                     attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
+                    attn_metadata.
+                    paged_kv_last_page_len_cpu[prefill_start:],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
@@ -373,6 +385,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     q_data_type=attn_metadata.q_data_type,
                     kv_data_type=attn_metadata.kv_data_type,
                 )
+            else:
+                attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
+                attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
+                    self.device)

         if num_decodes > 0:
             pure_decode = num_prefills == 0
@@ -400,11 +416,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

             attn_metadata.decode_wrapper = self._get_decode_wrapper(
                 num_input_tokens, use_cudagraph)
-            if not use_trtllm_decode_attention(
-                    num_decodes, attn_metadata.max_seq_len,
-                    self.cache_config.cache_dtype,
-                    attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
-                    attn_metadata.head_dim):
+            if not attn_metadata.decode_use_trtllm:
                 # Use the persistent buffer with padding length,
                 # instead of the same address but chunked version
                 # in atten_metadata when using cudagraph.
@@ -437,6 +449,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             split_decodes_and_prefills(common_attn_metadata)

         page_size = self.kv_cache_spec.block_size
+        max_q_len = common_attn_metadata.max_query_len
         max_seq_len = common_attn_metadata.seq_lens_cpu.max()
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
@@ -503,6 +516,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  cache_dtype)
         else:
             kv_cache_dtype = self.kv_cache_spec.dtype

+        num_qo_heads = self.vllm_config.model_config.get_num_attention_heads(
+            self.vllm_config.parallel_config)
+        num_kv_heads = self.kv_cache_spec.num_kv_heads
+        head_dim = self.kv_cache_spec.head_size
+
+        # currently prefill trtllm attention does not support fp8 kv cache
+        # trtllm may not support sliding window
+        prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
+                              and not cache_dtype.startswith("fp8")
+                              and use_trtllm_attention(
+                                  num_prefill_tokens, max_seq_len, cache_dtype,
+                                  num_qo_heads, num_kv_heads, head_dim))
+        decode_use_trtllm = (self.global_hyperparameters.window_left == -1
+                             and use_trtllm_attention(
+                                 num_decode_tokens, max_seq_len, cache_dtype,
+                                 num_qo_heads, num_kv_heads, head_dim))
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
@@ -510,14 +541,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len_cpu=self.
             paged_kv_last_page_len_cpu[:num_reqs],
-            num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
-                self.vllm_config.parallel_config),
-            num_kv_heads=self.kv_cache_spec.num_kv_heads,
-            head_dim=self.kv_cache_spec.head_size,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
             page_size=page_size,
             kv_data_type=kv_cache_dtype,
             q_data_type=self.vllm_config.model_config.dtype,
             slot_mapping=common_attn_metadata.slot_mapping,
+            max_q_len=max_q_len,
+            max_seq_len=max_seq_len,
+            seq_lens=seq_lens,
+            block_table_tensor=block_table_tensor,
+            prefill_use_trtllm=prefill_use_trtllm,
+            decode_use_trtllm=decode_use_trtllm,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
             num_prefills=num_prefills,
@@ -527,12 +563,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
             shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
             shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table_tensor=block_table_tensor,
         )

-        self._plan(num_prefills, num_decodes, attn_metadata)
+        self._plan(attn_metadata)

         return attn_metadata

@@ -698,14 +731,17 @@ class FlashInferImpl(AttentionImpl):
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
-        if prefill_wrapper := attn_metadata.prefill_wrapper:
+        if num_prefill_tokens > 0:
+            prefill_wrapper = attn_metadata.prefill_wrapper
             prefill_query = query[num_decode_tokens:]
             assert prefill_query.shape[0] == num_prefill_tokens
             assert prefill_wrapper is not None
+
+            if not attn_metadata.prefill_use_trtllm:
                 assert prefill_wrapper._causal
                 assert prefill_wrapper._window_left == window_left
-                assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
-                                                            or 0.0)
+                assert prefill_wrapper._logits_soft_cap == (
+                    self.logits_soft_cap or 0.0)
                 assert prefill_wrapper._sm_scale == self.scale
                 prefill_wrapper.run(
                     prefill_query,
@@ -714,14 +750,45 @@ class FlashInferImpl(AttentionImpl):
                     v_scale=layer._v_scale_float,
                     out=output[num_decode_tokens:],
                 )
-        if decode_wrapper := attn_metadata.decode_wrapper:
+            else:
+                # prefill_query may be non-contiguous
+                prefill_query = prefill_query.contiguous()
+                workspace_buffer = prefill_wrapper._float_workspace_buffer
+                block_tables_prefill = attn_metadata.block_table_tensor[
+                    num_decode_tokens:]
+                seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
+
+                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
+                assert get_kv_cache_layout() == "HND"
+                assert prefill_query.is_contiguous()
+                assert kv_cache_permute.is_contiguous()
+                assert workspace_buffer.is_contiguous()
+                assert block_tables_prefill.is_contiguous()
+                assert seq_lens_prefill.is_contiguous()
+
+                trtllm_batch_context_with_kv_cache(
+                    query=prefill_query,
+                    kv_cache=kv_cache_permute,
+                    workspace_buffer=workspace_buffer,
+                    block_tables=block_tables_prefill,
+                    seq_lens=seq_lens_prefill,
+                    max_q_len=attn_metadata.max_q_len,
+                    max_kv_len=attn_metadata.max_seq_len,
+                    bmm1_scale=layer._k_scale_float * self.scale,
+                    bmm2_scale=layer._v_scale_float,
+                    batch_size=attn_metadata.num_prefills,
+                    cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
+                    cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
+                    out=output[num_decode_tokens:],
+                )
+
+        if num_decode_tokens > 0:
+            decode_wrapper = attn_metadata.decode_wrapper
             decode_query = query[:num_decode_tokens]
             assert decode_query.shape[0] == num_decode_tokens
             assert decode_wrapper is not None
-            if not use_trtllm_decode_attention(
-                    attn_metadata.num_decodes, attn_metadata.max_seq_len,
-                    self.kv_cache_dtype, attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
+
+            if not attn_metadata.decode_use_trtllm:
                 assert decode_wrapper._window_left == window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                            or 0.0)
@@ -734,22 +801,20 @@ class FlashInferImpl(AttentionImpl):
                     out=output[:num_decode_tokens],
                 )
             else:
-                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
-                if num_decode_tokens > 0:
-                    # decode_query may be non-contiguous
-                    decode_query = decode_query.contiguous()
-                    block_tables_decode = attn_metadata.block_table_tensor[:
-                                                                num_decode_tokens]
-                    seq_lens_decode = attn_metadata.seq_lens[:
-                                                             num_decode_tokens]
-                    workspace_buffer = decode_wrapper._float_workspace_buffer
+                # decode_query may be non-contiguous
+                decode_query = decode_query.contiguous()
+                workspace_buffer = decode_wrapper._float_workspace_buffer
+                block_tables_decode = attn_metadata.block_table_tensor[:
+                                                                num_decode_tokens]
+                seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

+                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                 assert get_kv_cache_layout() == "HND"
                 assert decode_query.is_contiguous()
                 assert kv_cache_permute.is_contiguous()
+                assert workspace_buffer.is_contiguous()
                 assert block_tables_decode.is_contiguous()
                 assert seq_lens_decode.is_contiguous()
-                assert workspace_buffer.is_contiguous()

                 trtllm_batch_decode_with_kv_cache(
                     query=decode_query,