[NVIDIA] Support Flashinfer TRT-LLM Prefill Attention Kernel (#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
elvischenv 2025-08-05 17:45:34 +08:00 committed by GitHub
parent 4771df7b2b
commit 83156c7b89
9 changed files with 700 additions and 234 deletions

View File

@@ -664,7 +664,7 @@ steps:
   # Attention
   # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
   - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
-  - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
+  - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
   - pytest -v -s tests/kernels/test_cutlass_mla_decode.py
   # Quantization
   - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'

View File

@@ -41,7 +41,6 @@ def benchmark_decode(
     device = "cuda"
     torch.manual_seed(0)
-    # Currently only HEAD_GRP_SIZE == 8 is supported
     HEAD_GRP_SIZE = 8
     MAX_SEQ_LEN = max_seq_len

View File

@@ -0,0 +1,250 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import csv
import os
import random
from datetime import datetime
import flashinfer
import torch
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax * 0.1
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
@torch.no_grad()
def benchmark_prefill(
num_seqs,
max_seq_len,
page_size=16,
dtype=torch.bfloat16,
kv_layout="HND",
num_kv_heads=8,
kv_cache_dtype="auto",
head_dim=128,
warmup=10,
trials=20,
):
torch.set_default_device("cuda")
torch.manual_seed(0)
HEAD_GRP_SIZE = 8
MAX_SEQ_LEN = max_seq_len
# large number to reduce kv_cache reuse
NUM_BLOCKS = int(256000 / page_size)
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8)
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
sm_scale = float(1.0 / (head_dim**0.5))
q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
q_lens[-1] = MAX_SEQ_LEN
max_q_len = max(q_lens)
q_indptr = torch.cat(
[
torch.tensor([0], dtype=torch.int32),
torch.cumsum(
torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32
),
]
)
q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype)
kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)]
kv_lens[-1] = MAX_SEQ_LEN
seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)]
max_seq_len = max(seq_lens)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size
block_tables = torch.randint(
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype)
k_scale = v_scale = 1.0
if kv_cache_dtype.startswith("fp8"):
kv_cache, _ = to_float8(kv_cache)
output_trtllm = torch.empty(q.shape, dtype=dtype)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + page_size - 1) // page_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % page_size
if kv_last_page_len == 0:
kv_last_page_len = page_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
output_baseline = torch.empty(q.shape, dtype=dtype)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
)
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=kv_cache.dtype,
)
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
times = []
for i in range(warmup):
fn()
for i in range(trials):
start.record()
fn()
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
def baseline_prefill():
return wrapper.run(
q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline
)
def trt_prefill():
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens_tensor,
max_q_len=max_q_len,
max_kv_len=max_seq_len,
bmm1_scale=k_scale * sm_scale,
bmm2_scale=v_scale,
batch_size=num_seqs,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
out=output_trtllm,
)
trt_mean, trt_std = time_fn(trt_prefill)
baseline_mean, baseline_std = time_fn(baseline_prefill)
# Calculate percentage speedup (positive means TRT is faster)
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
print(
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}"
f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}"
)
# Return results for CSV writing
return {
"num_seqs": num_seqs,
"trt_mean": trt_mean,
"trt_std": trt_std.item(),
"baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent,
"q_dtype": str(dtype),
"kv_cache_dtype": kv_cache_dtype,
"page_size": page_size,
"num_kv_heads": num_kv_heads,
"head_dim": head_dim,
"max_seq_len": max_seq_len,
}
def write_results_to_csv(results, filename=None):
"""Write benchmark results to CSV file."""
if filename is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
fieldnames = [
"num_seqs",
"trt_mean",
"trt_std",
"baseline_mean",
"baseline_std",
"speedup_percent",
"q_dtype",
"kv_cache_dtype",
"page_size",
"num_kv_heads",
"head_dim",
"max_seq_len",
]
file_exists = os.path.exists(filename)
with open(filename, "a", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
if not file_exists:
writer.writeheader()
for result in results:
writer.writerow(result)
print(f"Results written to {filename}")
if __name__ == "__main__":
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []
print(
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_prefill(
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="auto",
)
all_results.append(result)
# Write all results to CSV
write_results_to_csv(all_results)
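
For a quick sanity check, the benchmark above can also be driven for a single configuration instead of the full sweep. A minimal sketch (the module name benchmark_trtllm_prefill_attention is an assumption; the diff does not show the file name):

# Minimal sketch: run one benchmark configuration and write it to CSV.
# The module name `benchmark_trtllm_prefill_attention` is an assumption;
# the diff does not show the file name.
import torch

from benchmark_trtllm_prefill_attention import (benchmark_prefill,
                                                write_results_to_csv)

result = benchmark_prefill(
    num_seqs=16,
    max_seq_len=4096,
    dtype=torch.bfloat16,
    kv_cache_dtype="auto",  # "fp8" would quantize the KV cache via to_float8
)
write_results_to_csv([result], filename="trtllm_prefill_single_run.csv")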

View File

@@ -0,0 +1,293 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import flashinfer
import pytest
import torch
from vllm.platforms import current_platform
if not current_platform.is_device_capability(100):
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
allow_module_level=True)
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
MAX_Q_LEN = 1024
MAX_KV_LEN = 4096
BATCH_SIZES = [4, 12]
NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16, 32]
KV_LAYOUTS = ["HND"]
DTYPES = [torch.float16, torch.bfloat16]
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 50.0]
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax * 0.1
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
batch_size: int,
num_heads: tuple[int, int],
head_size: int,
block_size: int,
kv_layout: str,
dtype: torch.dtype,
kv_cache_dtype: Optional[torch.dtype],
soft_cap: Optional[float],
) -> None:
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
torch.set_default_device("cuda")
current_platform.seed_everything(0)
kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = MAX_KV_LEN
max_kv_len = torch.max(kv_lens).item()
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
scale = head_size**-0.5
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
kv_scale = 1.0
if kv_cache_dtype is current_platform.fp8_dtype():
key_value_cache, kv_scale = to_float8(key_value_cache,
current_platform.fp8_dtype())
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
k_scale = v_scale = kv_scale
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_query_heads // num_kv_heads) > 4))
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
sm_scale=scale,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype)
wrapper.run(query,
key_value_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output)
# TRTLLM Decode
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
output_trtllm = torch.empty(query.shape, dtype=dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query.contiguous(),
kv_cache=key_value_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=kv_lens_tensor,
max_seq_len=max_kv_len,
bmm1_scale=k_scale * scale,
bmm2_scale=v_scale,
out=output_trtllm,
)
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - output_trtllm))}"
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
batch_size: int,
num_heads: tuple[int, int],
head_size: int,
block_size: int,
kv_layout: str,
dtype: torch.dtype,
kv_cache_dtype: Optional[torch.dtype],
soft_cap: Optional[float],
) -> None:
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
if dtype != kv_cache_dtype:
pytest.skip(f"Not supported dtype({dtype}) with "
"kv_cache_dtype({kv_cache_dtype})")
torch.set_default_device("cuda")
current_platform.seed_everything(0)
q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32)
q_lens[-1] = MAX_Q_LEN
max_q_len = torch.max(q_lens).item()
q_indptr = torch.cat([
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
])
kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = MAX_KV_LEN
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
num_seqs = len(seq_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
scale = head_size**-0.5
query = torch.randn(torch.sum(q_lens).item(),
num_query_heads,
head_size,
dtype=dtype)
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
kv_scale = 1.0
if kv_cache_dtype is current_platform.fp8_dtype():
key_value_cache, kv_scale = to_float8(key_value_cache,
current_platform.fp8_dtype())
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
k_scale = v_scale = kv_scale
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout)
wrapper.plan(q_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
causal=True,
sm_scale=scale,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype)
wrapper.run(query,
key_value_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output)
# TRTLLM Prefill
output_trtllm = torch.empty(query.shape, dtype=dtype)
flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=query.contiguous(),
kv_cache=key_value_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens,
max_q_len=max_q_len,
max_kv_len=max_seq_len,
bmm1_scale=k_scale * scale,
bmm2_scale=v_scale,
batch_size=num_seqs,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
out=output_trtllm,
)
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - output_trtllm))}"

View File

@@ -1,138 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import flashinfer
import pytest
import torch
from vllm.platforms import current_platform
if not current_platform.is_device_capability(100):
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
allow_module_level=True)
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 30.0, 50.0]
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax * 0.1
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", ["HND"])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
kv_lens: list[int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
kv_layout: str,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
k_scale = v_scale = 1.0
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout,
use_tensor_cores=(
(num_query_heads//num_kv_heads) > 4)
)
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype)
wrapper.run(query, key_value_cache, scale, out=output)
# TRTLLM Decode
max_kv_len = max(kv_lens)
kv_lens_tensor = torch.tensor(kv_lens,
dtype=torch.int,
device=query.device)
output_trtllm = torch.empty(query.shape, dtype=dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query.contiguous(),
key_value_cache,
workspace_buffer,
block_tables,
kv_lens_tensor,
max_kv_len,
bmm1_scale=k_scale * scale,
bmm2_scale=v_scale,
out=output_trtllm,
)
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - output_trtllm))}"

View File

@@ -46,7 +46,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
-from vllm.utils.flashinfer import use_trtllm_decode_attention
+from vllm.utils.flashinfer import use_trtllm_attention
 logger = init_logger(__name__)
@@ -1114,7 +1114,7 @@ class FlashInferImpl(AttentionImpl):
                 assert decode_meta.decode_wrapper._sm_scale == softmax_scale
                 # TODO: @pavanimajety Remove this once the switch happens
                 # inside flashinfer.
-                if not use_trtllm_decode_attention(
+                if not use_trtllm_attention(
                         num_decode_tokens, attn_metadata.max_decode_seq_len,
                         kv_cache_dtype, attn_metadata.num_qo_heads,
                         attn_metadata.num_kv_heads, attn_metadata.head_dim):

View File

@@ -1027,9 +1027,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
-    # If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
-    "VLLM_USE_TRTLLM_DECODE_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
+    # If set to 1, use the TRTLLM Attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_ATTENTION":
+    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
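
Because vllm.envs resolves variables lazily, the renamed knob takes effect as long as it is set before it is read. A small sketch (values other than "0" opt in, subject to the shape checks in use_trtllm_attention below):

# Sketch: toggle the renamed knob before vLLM reads it.
# Unset  -> auto-detection in use_trtllm_attention()
# "0"    -> force-disable
# other  -> opt in (shape checks still apply)
import os

os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "1"

import vllm.envs as envs

assert envs.VLLM_USE_TRTLLM_ATTENTION == "1"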

View File

@@ -124,7 +124,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
 @functools.cache
 def has_nvidia_artifactory() -> bool:
     """Return ``True`` if NVIDIA's artifactory is accessible.
     This checks connectivity to the kernel inference library artifactory
     which is required for downloading certain cubin kernels like TRTLLM FHMA.
     """
@@ -144,7 +144,7 @@ def has_nvidia_artifactory() -> bool:
         return False
-def use_trtllm_decode_attention(
+def use_trtllm_attention(
     num_tokens: int,
     max_seq_len: int,
     kv_cache_dtype: str,
@@ -159,29 +159,26 @@ def use_trtllm_decode_attention(
     # Check if the dimensions are supported by TRTLLM decode attention
     if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
-            or num_qo_heads // num_kv_heads > 8
             or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
         return False
-    env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
+    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
     if env_value is not None:
-        logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
-                         env_value)
+        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
         # Environment variable is set - respect it
         # Making the conditional check for zero because
         # the path is automatically enabled if the batch size condition
         # is satisfied.
         no_use_trtllm = (env_value == "0")
         if not no_use_trtllm:
-            logger.info_once("Using TRTLLM decode attention.")
+            logger.info_once("Using TRTLLM attention.")
         return not no_use_trtllm
     else:
         # Environment variable not set - use auto-detection
         use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
                       and kv_cache_dtype == "auto")
         if use_trtllm:
-            logger.warning_once(
-                "Using TRTLLM decode attention (auto-detected).")
+            logger.warning_once("Using TRTLLM attention (auto-detected).")
         return use_trtllm
@@ -195,5 +192,5 @@ __all__ = [
     "has_flashinfer_moe",
     "has_flashinfer_cutlass_fused_moe",
     "has_nvidia_artifactory",
-    "use_trtllm_decode_attention",
+    "use_trtllm_attention",
 ]

View File

@@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import (_get_range_buf, get_seq_lens,
                                trtllm_batch_decode_with_kv_cache)
+from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import use_trtllm_decode_attention
+from vllm.utils.flashinfer import use_trtllm_attention
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -149,9 +150,12 @@ class FlashInferMetadata:
     slot_mapping: torch.Tensor
     # For flashinfer trtllm batch decode
+    max_q_len: int
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
+    prefill_use_trtllm: bool
+    decode_use_trtllm: bool
     # For handling prefill decode split
     num_decodes: int
@@ -170,6 +174,9 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
+    qo_indptr_gpu: Optional[torch.Tensor] = None
+    paged_kv_indptr_gpu: Optional[torch.Tensor] = None
     def __post_init__(self):
         if self.head_dim is not None:
             FlashInferBackend.validate_head_size(self.head_dim)
@@ -305,8 +312,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 2, self._get_workspace_buffer(), get_kv_cache_layout())
         return self._cascade_wrapper
-    def _plan(self, num_prefills: int, num_decodes: int,
-              attn_metadata: FlashInferMetadata):
+    def _plan(self, attn_metadata: FlashInferMetadata):
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
@@ -341,6 +347,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             # Regular attention (common case).
             # Decodes are at the front and prefills are at the back,
             # according to reorder_batch()
+            num_prefills = attn_metadata.num_prefills
+            num_decodes = attn_metadata.num_decodes
             if num_prefills > 0:
                 # Decodes are first so prefills start after the last decode
                 prefill_start = num_decodes
@@ -356,23 +364,31 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 # to be relative to the start of the prefill queries.
                 qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
                     prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
-                attn_metadata.prefill_wrapper.plan(
-                    qo_indptr_cpu,
-                    attn_metadata.paged_kv_indptr_cpu[prefill_start:],
-                    attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
-                    attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads,
-                    attn_metadata.head_dim,
-                    attn_metadata.page_size,
-                    causal=True,
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.
-                    logits_soft_cap,
-                    q_data_type=attn_metadata.q_data_type,
-                    kv_data_type=attn_metadata.kv_data_type,
-                )
+                paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[
+                    prefill_start:]
+                if not attn_metadata.prefill_use_trtllm:
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu,
+                        paged_kv_indptr_cpu,
+                        attn_metadata.paged_kv_indices,
+                        attn_metadata.
+                        paged_kv_last_page_len_cpu[prefill_start:],
+                        attn_metadata.num_qo_heads,
+                        attn_metadata.num_kv_heads,
+                        attn_metadata.head_dim,
+                        attn_metadata.page_size,
+                        causal=True,
+                        sm_scale=self.global_hyperparameters.sm_scale,
+                        window_left=self.global_hyperparameters.window_left,
+                        logits_soft_cap=self.global_hyperparameters.
+                        logits_soft_cap,
+                        q_data_type=attn_metadata.q_data_type,
+                        kv_data_type=attn_metadata.kv_data_type,
+                    )
+                else:
+                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
+                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
+                        self.device)
             if num_decodes > 0:
                 pure_decode = num_prefills == 0
@@ -400,11 +416,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     attn_metadata.decode_wrapper = self._get_decode_wrapper(
                         num_input_tokens, use_cudagraph)
-                if not use_trtllm_decode_attention(
-                        num_decodes, attn_metadata.max_seq_len,
-                        self.cache_config.cache_dtype,
-                        attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
-                        attn_metadata.head_dim):
+                if not attn_metadata.decode_use_trtllm:
                     # Use the persistent buffer with padding length,
                     # instead of the same address but chunked version
                     # in atten_metadata when using cudagraph.
@@ -437,6 +449,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             split_decodes_and_prefills(common_attn_metadata)
         page_size = self.kv_cache_spec.block_size
+        max_q_len = common_attn_metadata.max_query_len
         max_seq_len = common_attn_metadata.seq_lens_cpu.max()
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
@@ -503,6 +516,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 cache_dtype)
         else:
             kv_cache_dtype = self.kv_cache_spec.dtype
+        num_qo_heads = self.vllm_config.model_config.get_num_attention_heads(
+            self.vllm_config.parallel_config)
+        num_kv_heads = self.kv_cache_spec.num_kv_heads
+        head_dim = self.kv_cache_spec.head_size
+        # currently prefill trtllm attention does not support fp8 kv cache
+        # trtllm may not support sliding window
+        prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
+                              and not cache_dtype.startswith("fp8")
+                              and use_trtllm_attention(
+                                  num_prefill_tokens, max_seq_len, cache_dtype,
+                                  num_qo_heads, num_kv_heads, head_dim))
+        decode_use_trtllm = (self.global_hyperparameters.window_left == -1
+                             and use_trtllm_attention(
+                                 num_decode_tokens, max_seq_len, cache_dtype,
+                                 num_qo_heads, num_kv_heads, head_dim))
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
@@ -510,14 +541,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len_cpu=self.
             paged_kv_last_page_len_cpu[:num_reqs],
-            num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
-                self.vllm_config.parallel_config),
-            num_kv_heads=self.kv_cache_spec.num_kv_heads,
-            head_dim=self.kv_cache_spec.head_size,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
             page_size=page_size,
             kv_data_type=kv_cache_dtype,
             q_data_type=self.vllm_config.model_config.dtype,
             slot_mapping=common_attn_metadata.slot_mapping,
+            max_q_len=max_q_len,
+            max_seq_len=max_seq_len,
+            seq_lens=seq_lens,
+            block_table_tensor=block_table_tensor,
+            prefill_use_trtllm=prefill_use_trtllm,
+            decode_use_trtllm=decode_use_trtllm,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
             num_prefills=num_prefills,
@@ -527,12 +563,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
             shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
             shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table_tensor=block_table_tensor,
         )
-        self._plan(num_prefills, num_decodes, attn_metadata)
+        self._plan(attn_metadata)
         return attn_metadata
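
The TRT-LLM kernels wired up here consume the paged KV cache in the HND layout noted at the top of the new benchmark and test files, kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim). A small illustrative sketch (sizes are made up, not taken from the diff):

# Sketch: a paged KV cache in the HND layout consumed by the TRT-LLM kernels,
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim).
# The sizes below are illustrative only.
import torch

num_blocks, num_kv_heads, page_size, head_dim = 1024, 8, 16, 128
kv_cache = torch.randn(num_blocks, 2, num_kv_heads, page_size, head_dim,
                       dtype=torch.bfloat16, device="cuda")
k_cache, v_cache = kv_cache.unbind(dim=1)  # views over the K and V planes
print(k_cache.shape)  # torch.Size([1024, 8, 16, 128])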
@@ -698,30 +731,64 @@ class FlashInferImpl(AttentionImpl):
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
-        if prefill_wrapper := attn_metadata.prefill_wrapper:
+        if num_prefill_tokens > 0:
+            prefill_wrapper = attn_metadata.prefill_wrapper
             prefill_query = query[num_decode_tokens:]
             assert prefill_query.shape[0] == num_prefill_tokens
             assert prefill_wrapper is not None
-            assert prefill_wrapper._causal
-            assert prefill_wrapper._window_left == window_left
-            assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
-                                                        or 0.0)
-            assert prefill_wrapper._sm_scale == self.scale
-            prefill_wrapper.run(
-                prefill_query,
-                kv_cache_permute,
-                k_scale=layer._k_scale_float,
-                v_scale=layer._v_scale_float,
-                out=output[num_decode_tokens:],
-            )
-        if decode_wrapper := attn_metadata.decode_wrapper:
+            if not attn_metadata.prefill_use_trtllm:
+                assert prefill_wrapper._causal
+                assert prefill_wrapper._window_left == window_left
+                assert prefill_wrapper._logits_soft_cap == (
+                    self.logits_soft_cap or 0.0)
+                assert prefill_wrapper._sm_scale == self.scale
+                prefill_wrapper.run(
+                    prefill_query,
+                    kv_cache_permute,
+                    k_scale=layer._k_scale_float,
+                    v_scale=layer._v_scale_float,
+                    out=output[num_decode_tokens:],
+                )
+            else:
+                # prefill_query may be non-contiguous
+                prefill_query = prefill_query.contiguous()
+                workspace_buffer = prefill_wrapper._float_workspace_buffer
+                block_tables_prefill = attn_metadata.block_table_tensor[
+                    num_decode_tokens:]
+                seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
+                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
+                assert get_kv_cache_layout() == "HND"
+                assert prefill_query.is_contiguous()
+                assert kv_cache_permute.is_contiguous()
+                assert workspace_buffer.is_contiguous()
+                assert block_tables_prefill.is_contiguous()
+                assert seq_lens_prefill.is_contiguous()
+                trtllm_batch_context_with_kv_cache(
+                    query=prefill_query,
+                    kv_cache=kv_cache_permute,
+                    workspace_buffer=workspace_buffer,
+                    block_tables=block_tables_prefill,
+                    seq_lens=seq_lens_prefill,
+                    max_q_len=attn_metadata.max_q_len,
+                    max_kv_len=attn_metadata.max_seq_len,
+                    bmm1_scale=layer._k_scale_float * self.scale,
+                    bmm2_scale=layer._v_scale_float,
+                    batch_size=attn_metadata.num_prefills,
+                    cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
+                    cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
+                    out=output[num_decode_tokens:],
+                )
+        if num_decode_tokens > 0:
+            decode_wrapper = attn_metadata.decode_wrapper
             decode_query = query[:num_decode_tokens]
             assert decode_query.shape[0] == num_decode_tokens
             assert decode_wrapper is not None
-            if not use_trtllm_decode_attention(
-                    attn_metadata.num_decodes, attn_metadata.max_seq_len,
-                    self.kv_cache_dtype, attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
+            if not attn_metadata.decode_use_trtllm:
                 assert decode_wrapper._window_left == window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                            or 0.0)
@@ -734,34 +801,32 @@ class FlashInferImpl(AttentionImpl):
                     out=output[:num_decode_tokens],
                 )
             else:
-                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
-                if num_decode_tokens > 0:
-                    # decode_query may be non-contiguous
-                    decode_query = decode_query.contiguous()
-                    block_tables_decode = attn_metadata.block_table_tensor[:
-                        num_decode_tokens]
-                    seq_lens_decode = attn_metadata.seq_lens[:
-                        num_decode_tokens]
-                    workspace_buffer = decode_wrapper._float_workspace_buffer
-                    assert get_kv_cache_layout() == "HND"
-                    assert decode_query.is_contiguous()
-                    assert kv_cache_permute.is_contiguous()
-                    assert block_tables_decode.is_contiguous()
-                    assert seq_lens_decode.is_contiguous()
-                    assert workspace_buffer.is_contiguous()
-                    trtllm_batch_decode_with_kv_cache(
-                        query=decode_query,
-                        kv_cache=kv_cache_permute,
-                        workspace_buffer=workspace_buffer,
-                        block_tables=block_tables_decode,
-                        seq_lens=seq_lens_decode,
-                        max_seq_len=attn_metadata.max_seq_len,
-                        bmm1_scale=layer._k_scale_float * self.scale,
-                        bmm2_scale=layer._v_scale_float,
-                        out=output[:num_decode_tokens],
-                    )
+                # decode_query may be non-contiguous
+                decode_query = decode_query.contiguous()
+                workspace_buffer = decode_wrapper._float_workspace_buffer
+                block_tables_decode = attn_metadata.block_table_tensor[:
+                    num_decode_tokens]
+                seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
+                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
+                assert get_kv_cache_layout() == "HND"
+                assert decode_query.is_contiguous()
+                assert kv_cache_permute.is_contiguous()
+                assert workspace_buffer.is_contiguous()
+                assert block_tables_decode.is_contiguous()
+                assert seq_lens_decode.is_contiguous()
+                trtllm_batch_decode_with_kv_cache(
+                    query=decode_query,
+                    kv_cache=kv_cache_permute,
+                    workspace_buffer=workspace_buffer,
+                    block_tables=block_tables_decode,
+                    seq_lens=seq_lens_decode,
+                    max_seq_len=attn_metadata.max_seq_len,
+                    bmm1_scale=layer._k_scale_float * self.scale,
+                    bmm2_scale=layer._v_scale_float,
+                    out=output[:num_decode_tokens],
+                )
         return output_padded
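
Putting the pieces together, enabling these paths end to end hinges on the two environment knobs referenced above. A hedged sketch of a deployment-side setup (the model name is only an example and does not come from this diff):

# Sketch: environment needed for the TRT-LLM prefill/decode paths above.
# VLLM_KV_CACHE_LAYOUT must be HND (see the asserts); the model name below
# is an assumption for illustration.
import os

os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"
os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
outputs = llm.generate(["Hello, TRT-LLM prefill!"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)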
@@ -786,8 +851,8 @@ def fast_plan_decode(
     non_blocking: bool = True,
 ) -> None:
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
     cudagraph capture/replay, while the no cudagraph version turns back
     to the original plan.
     using original plan after passing host-side buffers:
     - only host-to-device copy of indptr and last_page_len buffers