[NVIDIA] Support Flashinfer TRTLLM FP8-q/kv/out Attention Kernel (#21716)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Authored by elvischenv, committed via GitHub on 2025-08-19 20:22:15 +08:00
commit 03752dba8f (parent 40f26734b9)
9 changed files with 916 additions and 500 deletions

View File

@@ -631,6 +631,7 @@ steps:
   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
   - vllm/v1/attention/backends/flashinfer.py
   - vllm/compilation/fusion.py
+  - vllm/compilation/fusion_attn.py
   commands:
   - nvidia-smi
   - python3 examples/offline_inference/basic/chat.py
@@ -647,6 +648,7 @@ steps:
   - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
   # Fusion
   - pytest -v -s tests/compile/test_fusion_all_reduce.py
+  - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
   ##### 1 GPU test #####
   ##### multi gpus test #####

View File

@@ -3,16 +3,14 @@
 import csv
 import os
-import random
 from datetime import datetime
+from typing import Optional

 import flashinfer
 import torch

 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+FP8_DTYPE = torch.float8_e4m3fn
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)


 def to_float8(x, dtype=torch.float8_e4m3fn):
@ -26,65 +24,107 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad() @torch.no_grad()
def benchmark_decode( def benchmark_decode(
num_seqs, dtype: torch.dtype,
max_seq_len, quant_dtypes: tuple[
page_size=16, Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
dtype=torch.bfloat16, ],
kv_layout="HND", batch_size: int,
num_kv_heads=8, max_seq_len: int,
kv_cache_dtype="auto", num_heads: tuple[int, int] = (64, 8),
head_dim=128, head_size: int = 128,
warmup=10, kv_layout: str = "HND",
trials=20, block_size: int = 16,
warmup: int = 10,
trials: int = 20,
): ):
torch.set_default_device("cuda") torch.set_default_device("cuda")
device = "cuda"
torch.manual_seed(0) torch.manual_seed(0)
HEAD_GRP_SIZE = 8 q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
MAX_SEQ_LEN = max_seq_len q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
sm_scale = float(1.0 / (head_size**0.5))
# large number to reduce kv_cache reuse # large number to reduce kv_cache reuse
NUM_BLOCKS = int(256000 / page_size) NUM_BLOCKS = int(256000 / block_size)
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
# For decode, batch_size is num_decode_token query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE if q_quant_dtype == FP8_DTYPE:
sm_scale = float(1.0 / (head_dim**0.5)) query, q_scale = to_float8(query)
q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) ref_query = query.to(dtype) * q_scale
kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] else:
q_scale = 1.0
ref_query = query
max_kv_len = max(kv_lens) kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) kv_lens[-1] = max_seq_len
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
seq_lens = kv_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint( block_tables = torch.randint(
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
) )
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
k_scale = v_scale = 1.0 kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
if kv_cache_dtype.startswith("fp8"): wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
kv_cache, _ = to_float8(kv_cache) workspace_buffer,
kv_layout,
output_trtllm = torch.empty(q.shape, dtype=dtype) use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
)
# Benchmark TRT decode wrapper.plan(
def trt_decode(): kv_indptr,
return flashinfer.decode.trtllm_batch_decode_with_kv_cache( kv_indices,
q, kv_last_page_lens,
kv_cache, num_qo_heads,
workspace_buffer, num_kv_heads,
block_tables, head_size,
kv_lens_tensor, block_size,
max_kv_len, "NONE",
bmm1_scale=k_scale * sm_scale, sm_scale=sm_scale,
bmm2_scale=v_scale, q_data_type=dtype,
out=output_trtllm, kv_data_type=dtype,
) )
def time_fn(fn, warmup=10, trials=20): def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize() torch.cuda.synchronize()
@ -101,74 +141,51 @@ def benchmark_decode(
times.append(start.elapsed_time(end)) # ms times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times)) return sum(times) / len(times), torch.std(torch.tensor(times))
# TRT Decode o_scale = 1.0
trt_mean, trt_std = time_fn(trt_decode) output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + page_size - 1) // page_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % page_size
if kv_last_page_len == 0:
kv_last_page_len = page_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
output_baseline = torch.empty(q.shape, dtype=dtype)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
)
wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
"NONE",
q_data_type=dtype,
kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
)
def baseline_decode(): def baseline_decode():
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline) return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
def trtllm_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
out=output_trtllm,
)
baseline_mean, baseline_std = time_fn(baseline_decode) baseline_mean, baseline_std = time_fn(baseline_decode)
trtllm_mean, trtllm_std = time_fn(trtllm_decode)
# Calculate percentage speedup (positive means TRT is faster) # Calculate percentage speedup (positive means TRT is faster)
speedup_percent = (baseline_mean - trt_mean) / baseline_mean speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
print( print(
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
) )
# Return results for CSV writing # Return results for CSV writing
return { return {
"num_seqs": num_seqs, "batch_size": batch_size,
"trt_mean": trt_mean, "trtllm_mean": trtllm_mean,
"trt_std": trt_std.item(), "trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean, "baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(), "baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent, "speedup_percent": speedup_percent,
"q_dtype": str(dtype), "q_dtype": str(q_quant_dtype),
"kv_cache_dtype": kv_cache_dtype, "kv_cache_dtype": str(kv_quant_dtype),
"page_size": page_size, "output_dtype": str(o_quant_dtype),
"block_size": block_size,
"num_kv_heads": num_kv_heads, "num_kv_heads": num_kv_heads,
"head_dim": head_dim, "head_size": head_size,
"max_seq_len": max_seq_len, "max_seq_len": max_seq_len,
} }
@ -180,17 +197,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
fieldnames = [ fieldnames = [
"num_seqs", "batch_size",
"trt_mean", "trtllm_mean",
"trt_std", "trtllm_std",
"baseline_mean", "baseline_mean",
"baseline_std", "baseline_std",
"speedup_percent", "speedup_percent",
"q_dtype", "q_dtype",
"kv_cache_dtype", "kv_cache_dtype",
"page_size", "output_dtype",
"block_size",
"num_kv_heads", "num_kv_heads",
"head_dim", "head_size",
"max_seq_len", "max_seq_len",
] ]
@ -209,45 +227,42 @@ def write_results_to_csv(results, filename=None):
if __name__ == "__main__": if __name__ == "__main__":
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = [] all_results = []
print( dtype = torch.bfloat16
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " quant_dtypes = [
"output_dtype: bfloat16" # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
) (None, None, None),
print( (None, FP8_DTYPE, None),
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
"baseline_std\tspeedup_percent" ]
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="auto",
)
all_results.append(result)
print( for quant_dtype in quant_dtypes:
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, " q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
"output_dtype: bfloat16" q_quant_dtype = q_quant_dtype or dtype
) kv_quant_dtype = kv_quant_dtype or dtype
print( o_quant_dtype = o_quant_dtype or dtype
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent" print(
) f"Running benchmark for q_dtype = {q_quant_dtype}, "
for max_seq_len in max_seq_lens: f"kv_cache_dtype: {kv_quant_dtype}, "
for bs in num_seqs: f"output_dtype: {o_quant_dtype}"
result = benchmark_decode( )
bs, print(
max_seq_len, "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
dtype=torch.bfloat16, "baseline_std\tspeedup_percent"
kv_cache_dtype="fp8", )
) for max_seq_len in max_seq_lens:
all_results.append(result) for bs in batch_sizes:
result = benchmark_decode(
dtype=dtype,
quant_dtypes=quant_dtype,
batch_size=bs,
max_seq_len=max_seq_len,
)
all_results.append(result)
# Write all results to CSV # Write all results to CSV
write_results_to_csv(all_results) write_results_to_csv(all_results)
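Note on the scales used above: the TRT-LLM decode path no longer receives separate q/k/v scales; the per-tensor dequant scales are folded into the two GEMM scales (bmm1_scale = q_scale * k_scale * sm_scale, bmm2_scale = v_scale / o_scale). A minimal numerical sketch of that folding, with quantization simulated by plain float tensors so it runs anywhere; all names are illustrative and this is not the kernel itself:

import torch

torch.manual_seed(0)
head_size = 8
q_scale, k_scale, v_scale, o_scale = 0.02, 0.03, 0.05, 0.1
sm_scale = head_size ** -0.5

# "Quantized" tensors and their dequantized counterparts
q_int = torch.randn(1, head_size)
k_int = torch.randn(16, head_size)
v_int = torch.randn(16, head_size)
q, k, v = q_int * q_scale, k_int * k_scale, v_int * v_scale

# Reference attention on the dequantized inputs
ref = torch.softmax(q @ k.T * sm_scale, dim=-1) @ v

# Same value computed directly on the "quantized" tensors with folded scales,
# emitted as a quantized output carrying scale o_scale
bmm1_scale = q_scale * k_scale * sm_scale
bmm2_scale = v_scale / o_scale
out_q = torch.softmax(q_int @ k_int.T * bmm1_scale, dim=-1) @ v_int * bmm2_scale
torch.testing.assert_close(out_q * o_scale, ref)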

View File

@@ -3,16 +3,14 @@
 import csv
 import os
-import random
 from datetime import datetime
+from typing import Optional

 import flashinfer
 import torch

 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+FP8_DTYPE = torch.float8_e4m3fn
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)


 def to_float8(x, dtype=torch.float8_e4m3fn):
@ -26,84 +24,99 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad() @torch.no_grad()
def benchmark_prefill( def benchmark_prefill(
num_seqs, dtype: torch.dtype,
max_seq_len, quant_dtypes: tuple[
page_size=16, Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
dtype=torch.bfloat16, ],
kv_layout="HND", batch_size: int,
num_kv_heads=8, max_seq_len: int,
kv_cache_dtype="auto", num_heads: tuple[int, int] = (64, 8),
head_dim=128, head_size: int = 128,
warmup=10, kv_layout: str = "HND",
trials=20, block_size: int = 16,
warmup: int = 10,
trials: int = 20,
): ):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.manual_seed(0) torch.manual_seed(0)
HEAD_GRP_SIZE = 8 q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
MAX_SEQ_LEN = max_seq_len q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
max_q_len = max_kv_len = max_seq_len
num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
sm_scale = float(1.0 / (head_size**0.5))
# large number to reduce kv_cache reuse # large number to reduce kv_cache reuse
NUM_BLOCKS = int(256000 / page_size) NUM_BLOCKS = int(256000 / block_size)
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8) kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
sm_scale = float(1.0 / (head_dim**0.5)) q_lens[-1] = max_q_len
q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
q_lens[-1] = MAX_SEQ_LEN
max_q_len = max(q_lens)
q_indptr = torch.cat( q_indptr = torch.cat(
[ [
torch.tensor([0], dtype=torch.int32), torch.tensor([0], dtype=torch.int32),
torch.cumsum( torch.cumsum(q_lens, dim=0, dtype=torch.int32),
torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32
),
] ]
) )
q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype)
kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)] query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
kv_lens[-1] = MAX_SEQ_LEN if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query
seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)] kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
max_seq_len = max(seq_lens) kv_lens[-1] = max_kv_len
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint( block_tables = torch.randint(
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
) )
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype)
k_scale = v_scale = 1.0
if kv_cache_dtype.startswith("fp8"):
kv_cache, _ = to_float8(kv_cache)
output_trtllm = torch.empty(q.shape, dtype=dtype)
kv_indptr = [0] kv_indptr = [0]
kv_indices = [] kv_indices = []
kv_last_page_lens = [] kv_last_page_lens = []
for i in range(num_seqs): for i in range(batch_size):
seq_len = seq_lens[i] seq_len = seq_lens[i]
assert seq_len > 0 assert seq_len > 0
num_blocks = (seq_len + page_size - 1) // page_size num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks]) kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks) kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % page_size kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0: if kv_last_page_len == 0:
kv_last_page_len = page_size kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len) kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
output_baseline = torch.empty(q.shape, dtype=dtype)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout workspace_buffer, kv_layout
@ -115,12 +128,12 @@ def benchmark_prefill(
kv_last_page_lens, kv_last_page_lens,
num_qo_heads, num_qo_heads,
num_kv_heads, num_kv_heads,
head_dim, head_size,
page_size, block_size,
causal=True, causal=True,
sm_scale=sm_scale, sm_scale=sm_scale,
q_data_type=dtype, q_data_type=dtype,
kv_data_type=kv_cache.dtype, kv_data_type=dtype,
) )
def time_fn(fn, warmup=10, trials=20): def time_fn(fn, warmup=10, trials=20):
@ -138,52 +151,55 @@ def benchmark_prefill(
times.append(start.elapsed_time(end)) # ms times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times)) return sum(times) / len(times), torch.std(torch.tensor(times))
def baseline_prefill(): o_scale = 1.0
return wrapper.run( output_baseline = torch.empty(ref_query.shape, dtype=dtype)
q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
)
def trt_prefill(): def baseline_prefill():
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
def trtllm_prefill():
return flashinfer.prefill.trtllm_batch_context_with_kv_cache( return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q, query=query,
kv_cache=kv_cache, kv_cache=kv_cache,
workspace_buffer=workspace_buffer, workspace_buffer=workspace_buffer,
block_tables=block_tables, block_tables=block_tables,
seq_lens=seq_lens_tensor, seq_lens=seq_lens,
max_q_len=max_q_len, max_q_len=max_q_len,
max_kv_len=max_seq_len, max_kv_len=max_seq_len,
bmm1_scale=k_scale * sm_scale, bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale, bmm2_scale=v_scale / o_scale,
batch_size=num_seqs, batch_size=batch_size,
cum_seq_lens_q=q_indptr, cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr, cum_seq_lens_kv=kv_indptr,
out=output_trtllm, out=output_trtllm,
) )
trt_mean, trt_std = time_fn(trt_prefill)
baseline_mean, baseline_std = time_fn(baseline_prefill) baseline_mean, baseline_std = time_fn(baseline_prefill)
trtllm_mean, trtllm_std = time_fn(trtllm_prefill)
# Calculate percentage speedup (positive means TRT is faster) # Calculate percentage speedup (positive means TRT is faster)
speedup_percent = (baseline_mean - trt_mean) / baseline_mean speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
print( print(
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}" f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}" f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
) )
# Return results for CSV writing # Return results for CSV writing
return { return {
"num_seqs": num_seqs, "batch_size": batch_size,
"trt_mean": trt_mean, "trtllm_mean": trtllm_mean,
"trt_std": trt_std.item(), "trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean, "baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(), "baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent, "speedup_percent": speedup_percent,
"q_dtype": str(dtype), "q_dtype": str(q_quant_dtype),
"kv_cache_dtype": kv_cache_dtype, "kv_cache_dtype": str(kv_quant_dtype),
"page_size": page_size, "output_dtype": str(o_quant_dtype),
"block_size": block_size,
"num_kv_heads": num_kv_heads, "num_kv_heads": num_kv_heads,
"head_dim": head_dim, "head_size": head_size,
"max_seq_len": max_seq_len, "max_seq_len": max_seq_len,
} }
@ -195,17 +211,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
fieldnames = [ fieldnames = [
"num_seqs", "batch_size",
"trt_mean", "trtllm_mean",
"trt_std", "trtllm_std",
"baseline_mean", "baseline_mean",
"baseline_std", "baseline_std",
"speedup_percent", "speedup_percent",
"q_dtype", "q_dtype",
"kv_cache_dtype", "kv_cache_dtype",
"page_size", "output_dtype",
"block_size",
"num_kv_heads", "num_kv_heads",
"head_dim", "head_size",
"max_seq_len", "max_seq_len",
] ]
@ -224,27 +241,41 @@ def write_results_to_csv(results, filename=None):
if __name__ == "__main__": if __name__ == "__main__":
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = [] all_results = []
print( dtype = torch.bfloat16
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " quant_dtypes = [
"output_dtype: bfloat16" # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
) (None, None, None),
print( (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" ]
"baseline_std\tspeedup_percent"
) for quant_dtype in quant_dtypes:
for max_seq_len in max_seq_lens: q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
for bs in num_seqs: q_quant_dtype = q_quant_dtype or dtype
result = benchmark_prefill( kv_quant_dtype = kv_quant_dtype or dtype
bs, o_quant_dtype = o_quant_dtype or dtype
max_seq_len,
dtype=torch.bfloat16, print(
kv_cache_dtype="auto", f"Running benchmark for q_dtype = {q_quant_dtype}, "
) f"kv_cache_dtype: {kv_quant_dtype}, "
all_results.append(result) f"output_dtype: {o_quant_dtype}"
)
print(
"\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in batch_sizes:
result = benchmark_prefill(
dtype=dtype,
quant_dtypes=quant_dtype,
batch_size=bs,
max_seq_len=max_seq_len,
)
all_results.append(result)
# Write all results to CSV # Write all results to CSV
write_results_to_csv(all_results) write_results_to_csv(all_results)
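Both benchmarks build a bf16 reference copy of each quantized input as quantized.to(dtype) * scale. A sketch of the per-tensor FP8 helper behind that pattern, consistent with the to_float8 return statement shown in the kernel test below; it assumes a PyTorch build with float8_e4m3fn support and the details are illustrative:

import torch

def to_float8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    # One scale for the whole tensor: map the absolute max onto the fp8 max
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    # Return the quantized tensor and the *dequant* scale (reciprocal)
    return x_scl_sat.to(dtype), scale.float().reciprocal()

x = torch.randn(32, 128, dtype=torch.bfloat16)
x_fp8, dq_scale = to_float8(x)
ref = x_fp8.to(torch.bfloat16) * dq_scale  # what the baseline wrapper consumes
print("max abs err:", (x.float() - ref.float()).abs().max().item())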

View File

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 from typing import Optional

 import pytest
@@ -7,13 +8,27 @@ import torch._dynamo

 from tests.compile.backend import TestBackend
 from tests.models.utils import check_outputs_equal
+from tests.v1.attention.utils import (BatchSpec, _Backend,
+                                      create_common_attn_metadata)
 from vllm import LLM, SamplingParams
+from vllm.attention import Attention
+from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.noop_elimination import NoOpEliminationPass
-from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
+from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
+                         ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
+                         set_current_vllm_config)
+from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp)
 from vllm.platforms import current_platform
+from vllm.v1.kv_cache_interface import AttentionSpec
+
+FP8_DTYPE = current_platform.fp8_dtype()

 # globals needed for string-import custom Dynamo backend field
 backend: Optional[TestBackend] = None
@ -132,3 +147,235 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
# Reset backend to make sure llm2 gets released # Reset backend to make sure llm2 gets released
backend = None backend = None
class TestAttentionStaticQuantPatternModel(torch.nn.Module):
"""Test model for AttentionStaticQuantPattern fusion."""
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype, device: torch.device,
vllm_config: VllmConfig):
super().__init__()
self.num_qo_heads = num_qo_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.attn = Attention(
num_heads=self.num_qo_heads,
head_size=self.head_size,
scale=1.0 / (self.head_size**0.5),
num_kv_heads=self.num_kv_heads,
cache_config=vllm_config.cache_config,
prefix="model.layers.0.self_attn.attn",
)
self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
self.wscale = torch.tensor([1.0], dtype=torch.float32)
self.scale = torch.tensor([1.0], dtype=torch.float32)
self.block_size = 16
# Initialize attn MetadataBuilder
self.builder = self.attn.attn_backend.get_builder_cls()(
kv_cache_spec=AttentionSpec(
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_dtype,
use_mla=False,
),
layer_names=[self.attn.layer_name],
vllm_config=self.vllm_config,
device=self.device,
)
def build_attn_metadata(self, batch_size: int):
"""Initialize attention metadata."""
# Create common attn metadata
batch_spec = BatchSpec(seq_lens=[1] * batch_size,
query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
self.block_size,
self.device,
arange_block_indices=True)
max_blocks = (max(batch_spec.seq_lens) + self.block_size -
1) // self.block_size
num_blocks = batch_size * max_blocks
# Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
kv_cache = torch.zeros(num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
self.attn.kv_cache = [kv_cache]
# Build attn metadata
self.attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata)
return self.attn_metadata
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
w: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(input=attn_output,
weight=w,
weight_scale=self.wscale,
input_scale=self.scale)
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size", [7, 256, 533])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
"model_name, quant_key",
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)])
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)),
reason="Only test on SM100(Blackwell)")
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str,
quant_key: QuantKey, backend: _Backend,
monkeypatch, dist_init):
"""Test AttentionStaticQuantPattern fusion pass"""
monkeypatch.setenv("VLLM_USE_V1", "1")
device = torch.device("cuda:0")
torch.manual_seed(42)
vllm_config = VllmConfig(
model_config=ModelConfig(
model=model_name,
max_model_len=2048,
),
scheduler_config=SchedulerConfig(max_num_seqs=1024),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+quant_fp8"],
),
cache_config=CacheConfig(cache_dtype="fp8"))
# Create test inputs
hidden_size = num_qo_heads * head_size
q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
k = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
device=device)
v = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
device=device)
linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t()
# Mark first dimension as dynamic for realistic testing
torch._dynamo.mark_dynamic(q, 0)
torch._dynamo.mark_dynamic(k, 0)
torch._dynamo.mark_dynamic(v, 0)
# Run model directly without compilation and fusion
vllm_config_unfused = copy.deepcopy(vllm_config)
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
attn_metadata=None, vllm_config=vllm_config_unfused
), global_force_attn_backend_context_manager(backend):
model_unfused = TestAttentionStaticQuantPatternModel(
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
vllm_config_unfused)
model_unfused = model_unfused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v, linear_w)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
enable_attn_fusion=True, enable_noop=True)
with set_current_vllm_config(vllm_config), set_forward_context(
attn_metadata=None, vllm_config=vllm_config
), global_force_attn_backend_context_manager(backend):
model_fused = TestAttentionStaticQuantPatternModel(
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
vllm_config)
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
)
test_backend = TestBackend(noop_pass, attn_pass)
# Compile model with fusion enabled
model_compiled = torch.compile(model_fused,
backend=test_backend,
fullgraph=True)
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v, linear_w)
# After the 1st round of the forward pass, output quant scale should be
# loaded into the attn layer's _o_scale_float, the 2nd round should
# reuse the loaded _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v, linear_w)
assert model_compiled.attn._o_scale_float is not None
# Check attn fusion support
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype,
quant_key.static,
quant_key.group_shape) for key,
layer in vllm_config.compilation_config.static_forward_context.items()
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
test_backend.check_before_ops([QUANT_OPS[quant_key]],
fully_replaced=True)
# Check attention ops in the graph before and after fusion
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP,
test_backend.graph_post_pass))
assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
assert len(attn_nodes_pre) == len(attn_nodes_post), \
"Should have same number of attention nodes before and after fusion"
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \
"Attention should not have output_scale before fusion"
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
"Attention should have output_scale after fusion"
# Check that results are closed
torch.testing.assert_close(result_unfused,
result_fused_1,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)
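Regarding the dummy KV cache built in build_attn_metadata above: it is allocated in HND layout and permuted to NHD, since the two paged-KV layouts differ only by a transpose of the page and head axes. A quick standalone check of that relationship, using the shapes from this test (otherwise illustrative):

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
nhd = torch.randn(num_blocks, 2, block_size, num_kv_heads, head_size)
hnd = nhd.permute(0, 1, 3, 2, 4)  # -> (num_blocks, 2, num_kv_heads, block_size, head_size)

assert hnd.shape == (num_blocks, 2, num_kv_heads, block_size, head_size)
# permuting back with the same index order recovers the original view
assert torch.equal(hnd.permute(0, 1, 3, 2, 4), nhd)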

View File

@@ -13,21 +13,7 @@ if not current_platform.is_device_capability(100):
                 allow_module_level=True)

 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+FP8_DTYPE = current_platform.fp8_dtype()
-
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
-
-MAX_Q_LEN = 1024
-MAX_KV_LEN = 4096
-BATCH_SIZES = [4, 12]
-NUM_HEADS = [(16, 16), (40, 8)]
-HEAD_SIZES = [128]
-BLOCK_SIZES = [16]
-KV_LAYOUTS = ["HND"]
-DTYPES = [torch.bfloat16]
-KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
-NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
-SOFT_CAPS = [None, 50.0]
def to_float8(x, dtype=torch.float8_e4m3fn): def to_float8(x, dtype=torch.float8_e4m3fn):
@ -39,42 +25,59 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
return x_scl_sat.to(dtype), scale.float().reciprocal() return x_scl_sat.to(dtype), scale.float().reciprocal()
@pytest.mark.parametrize("batch_size", BATCH_SIZES) DTYPE = [torch.bfloat16]
QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
]
BATCH_SIZE = [4, 12]
MAX_SEQ_LENS = [(1024, 4096)]
NUM_HEADS = [(64, 8), (40, 8)]
HEAD_SIZE = [128]
KV_LAYOUT = ["HND"] # currently only HND is supported
BLOCK_SIZE = [16]
SOFT_CAP = [None, 50.0]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) @pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", SOFT_CAP)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode @torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline( def test_flashinfer_trtllm_decode_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
Optional[torch.dtype]],
batch_size: int, batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
block_size: int,
kv_layout: str, kv_layout: str,
dtype: torch.dtype, block_size: int,
kv_cache_dtype: Optional[torch.dtype],
soft_cap: Optional[float], soft_cap: Optional[float],
) -> None: ) -> None:
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) current_platform.seed_everything(0)
kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
kv_lens[-1] = MAX_KV_LEN q_quant_dtype = q_quant_dtype or dtype
max_kv_len = torch.max(kv_lens).item() kv_quant_dtype = kv_quant_dtype or dtype
num_seqs = len(kv_lens) o_quant_dtype = o_quant_dtype or dtype
num_query_heads = num_heads[0] _, max_kv_len = max_seq_lens
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
scale = head_size**-0.5 num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) sm_scale = float(1.0 / (head_size**0.5))
kv_cache_shape = None kv_cache_shape = None
if kv_layout == "NHD": if kv_layout == "NHD":
@ -83,156 +86,39 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else: else:
raise ValueError(f"Invalid kv_layout: {kv_layout}") raise ValueError(f"Invalid kv_layout: {kv_layout}")
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
kv_scale = 1.0
if kv_cache_dtype is current_platform.fp8_dtype():
key_value_cache, kv_scale = to_float8(key_value_cache,
current_platform.fp8_dtype())
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
block_tables = torch.randint(0, if q_quant_dtype == FP8_DTYPE:
NUM_BLOCKS, query, q_scale = to_float8(query)
(num_seqs, max_num_blocks_per_seq), ref_query = query.to(dtype) * q_scale
dtype=torch.int32) else:
k_scale = v_scale = kv_scale q_scale = 1.0
kv_indptr = [0] ref_query = query
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_lens[-1] = max_kv_len
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) seq_lens = kv_lens
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_query_heads // num_kv_heads) > 4))
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
sm_scale=scale,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype)
wrapper.run(query,
key_value_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output)
# TRTLLM Decode
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
output_trtllm = torch.empty(query.shape, dtype=dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query.contiguous(),
kv_cache=key_value_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=kv_lens_tensor,
max_seq_len=max_kv_len,
bmm1_scale=k_scale * scale,
bmm2_scale=v_scale,
out=output_trtllm,
)
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - output_trtllm))}"
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
batch_size: int,
num_heads: tuple[int, int],
head_size: int,
block_size: int,
kv_layout: str,
dtype: torch.dtype,
kv_cache_dtype: Optional[torch.dtype],
soft_cap: Optional[float],
) -> None:
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
if dtype != kv_cache_dtype:
pytest.skip(f"Not supported dtype({dtype}) with "
"kv_cache_dtype({kv_cache_dtype})")
torch.set_default_device("cuda")
current_platform.seed_everything(0)
q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32)
q_lens[-1] = MAX_Q_LEN
max_q_len = torch.max(q_lens).item()
q_indptr = torch.cat([
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
])
kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = MAX_KV_LEN
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item() max_seq_len = torch.max(seq_lens).item()
num_seqs = len(seq_lens)
num_query_heads = num_heads[0] kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
num_kv_heads = num_heads[1] if kv_quant_dtype == FP8_DTYPE:
assert num_query_heads % num_kv_heads == 0 kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
scale = head_size**-0.5
query = torch.randn(torch.sum(q_lens).item(),
num_query_heads,
head_size,
dtype=dtype)
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else: else:
raise ValueError(f"Invalid kv_layout: {kv_layout}") kv_scale = 1.0
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) ref_kv_cache = kv_cache
kv_scale = 1.0 k_scale = v_scale = kv_scale
if kv_cache_dtype is current_platform.fp8_dtype():
key_value_cache, kv_scale = to_float8(key_value_cache,
current_platform.fp8_dtype())
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0, block_tables = torch.randint(0,
NUM_BLOCKS, NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq), (batch_size, max_num_blocks_per_seq),
dtype=torch.int32) dtype=torch.int32)
k_scale = v_scale = kv_scale
kv_indptr = [0] kv_indptr = [0]
kv_indices = [] kv_indices = []
kv_last_page_lens = [] kv_last_page_lens = []
for i in range(num_seqs): for i in range(batch_size):
seq_len = seq_lens[i] seq_len = seq_lens[i]
assert seq_len > 0 assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size num_blocks = (seq_len + block_size - 1) // block_size
@ -246,48 +132,206 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Decode
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4))
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap)
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
# TRTLLM Decode
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}"
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
Optional[torch.dtype]],
batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
kv_layout: str,
block_size: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
max_q_len, max_kv_len = max_seq_lens
num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
sm_scale = float(1.0 / (head_size**0.5))
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32)
q_lens[-1] = max_q_len
q_indptr = torch.cat([
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
])
query = torch.randn(torch.sum(q_lens).item(),
num_qo_heads,
head_size,
dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query
kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = max_kv_len
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(batch_size, max_num_blocks_per_seq),
dtype=torch.int32)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Prefill
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout) workspace_buffer, kv_layout)
wrapper.plan(q_indptr, wrapper.plan(q_indptr,
kv_indptr, kv_indptr,
kv_indices, kv_indices,
kv_last_page_lens, kv_last_page_lens,
num_query_heads, num_qo_heads,
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
causal=True, causal=True,
sm_scale=scale, sm_scale=sm_scale,
q_data_type=dtype, q_data_type=dtype,
kv_data_type=kv_cache_dtype, kv_data_type=dtype,
logits_soft_cap=soft_cap) logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype) output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(query, wrapper.run(ref_query, ref_kv_cache, out=output)
key_value_cache, o_scale = 1.0
k_scale=k_scale, if o_quant_dtype == FP8_DTYPE:
v_scale=v_scale, _, o_scale = to_float8(output)
out=output)
# TRTLLM Decode # TRTLLM Prefill
output_trtllm = torch.empty(query.shape, dtype=dtype) output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.prefill.trtllm_batch_context_with_kv_cache( flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=query.contiguous(), query=query,
kv_cache=key_value_cache, kv_cache=kv_cache,
workspace_buffer=workspace_buffer, workspace_buffer=workspace_buffer,
block_tables=block_tables, block_tables=block_tables,
seq_lens=seq_lens, seq_lens=seq_lens,
max_q_len=max_q_len, max_q_len=max_q_len,
max_kv_len=max_seq_len, max_kv_len=max_seq_len,
bmm1_scale=k_scale * scale, bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale, bmm2_scale=v_scale / o_scale,
batch_size=num_seqs, batch_size=batch_size,
cum_seq_lens_q=q_indptr, cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr, cum_seq_lens_kv=kv_indptr,
out=output_trtllm, out=output_trtllm,
) )
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}" f"{torch.max(torch.abs(output - output_trtllm))}"
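When the output is FP8, the test dequantizes the TRT-LLM result with o_scale before comparing, and uses wider tolerances than the bf16 path. A standalone sketch of that dequant-then-compare step; the FP8 round trip here stands in for the kernel, and the tolerances are deliberately generous (the test above uses rtol=5e-2, atol=7e-2 after running both kernels):

import torch

FP8 = torch.float8_e4m3fn
finfo = torch.finfo(FP8)

baseline = torch.randn(4, 64, 128)  # stand-in for the baseline wrapper output

# o_scale derived from the reference output, as in `_, o_scale = to_float8(output)`
o_scale = (baseline.abs().max() / finfo.max).item()

fp8_out = (baseline / o_scale).clamp(finfo.min, finfo.max).to(FP8)  # kernel-style FP8 output
dequant = fp8_out.float() * o_scale

torch.testing.assert_close(dequant, baseline, rtol=1e-1, atol=1e-1)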

View File

@@ -128,11 +128,17 @@ class Attention(nn.Module):
         self._q_scale = torch.tensor(1.0, dtype=torch.float32)
         self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

-        # We also keep the float32 versions of k/v_scale for attention
-        # backends that don't support tensors (Flashinfer)
+        # We also keep q/k/v_scale on host (cpu) memory for attention
+        # backends that require the scales to be on host instead of on device.
+        # e.g. Flashinfer
+        self._q_scale_float = 1.0
         self._k_scale_float = 1.0
         self._v_scale_float = 1.0

+        # The output scale on host memory. This should be the input scale of
+        # the quant op after this attention layer.
+        self._o_scale_float: Optional[float] = None
+
         self.use_mla = use_mla
         self.num_heads = num_heads
         self.head_size = head_size
@@ -291,6 +297,7 @@ class Attention(nn.Module):
             self._q_scale.copy_(torch.abs(query).max() / self.q_range)
             self._k_scale.copy_(torch.abs(key).max() / self.k_range)
             self._v_scale.copy_(torch.abs(value).max() / self.v_range)
+            self._q_scale_float = self._q_scale.item()
             self._k_scale_float = self._k_scale.item()
             self._v_scale_float = self._v_scale.item()
             # We only calculate the scales once
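One note on these host-side copies: kernels such as the TRT-LLM attention take their scales as Python floats (e.g. bmm1_scale/bmm2_scale), so keeping .item() copies next to the tensor scales avoids a device sync in the hot forward path. A hypothetical sketch of how a backend might consume them, with a stand-in object in place of the real Attention layer; the scale arithmetic mirrors the benchmarks above and the values are illustrative:

class FakeAttentionLayer:  # stand-in for the fields added above
    head_size = 128
    _q_scale_float = 0.02
    _k_scale_float = 0.03
    _v_scale_float = 0.05
    _o_scale_float = 0.10  # filled in lazily once the fused output scale is known

layer = FakeAttentionLayer()
sm_scale = layer.head_size ** -0.5
bmm1_scale = layer._q_scale_float * layer._k_scale_float * sm_scale
bmm2_scale = (layer._v_scale_float / layer._o_scale_float
              if layer._o_scale_float is not None else layer._v_scale_float)
print(f"bmm1_scale={bmm1_scale:.6f}, bmm2_scale={bmm2_scale:.3f}")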

View File

@ -9,7 +9,7 @@ from torch._subclasses.fake_tensor import (FakeTensorMode,
unset_fake_temporarily) unset_fake_temporarily)
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -18,23 +18,32 @@ from .vllm_inductor_pass import VllmInductorPass
 logger = init_logger(__name__)

+FP8_DTYPE = current_platform.fp8_dtype()
 ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
 RESHAPE_OP = torch.ops.aten.reshape.default


 class AttentionStaticQuantPattern:
+    """
+    Fusion for Attention+StaticQuant.
+    Only triggers when the attention implementation returns True in
+    `fused_output_quant_supported()`. If the pattern is found, the StaticQuant
+    op will be removed from the graph, and its scale will be passed into
+    Attention op as the `output_scale` argument.
+    """

     def __init__(
         self,
-        layer_name: str,
-        num_heads: int,
-        head_size: int,
+        layer: Attention,
         quant_dtype: torch.dtype,
         symmetric=True,
     ):
-        self.layer_name = layer_name
-        self.num_heads = num_heads
-        self.head_size = head_size
+        self.layer = layer
+        self.layer_name = layer.layer_name
+        self.num_heads = layer.num_heads
+        self.head_size = layer.head_size
         self.quant_dtype = quant_dtype
         self.quant_key = QuantKey(dtype=quant_dtype,
                                   static=True,
@@ -48,11 +57,10 @@ class AttentionStaticQuantPattern:
         kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
         return torch.empty(*args, **kwargs)

-    def register_if_supported(self, pm_pass: PatternMatcherPass,
-                              layer: Attention):
-        if layer.impl.fused_output_quant_supported(self.quant_dtype,
-                                                   self.quant_key.static,
-                                                   self.quant_key.group_shape):
+    def register_if_supported(self, pm_pass: PatternMatcherPass):
+        if self.layer.impl.fused_output_quant_supported(
+                self.quant_dtype, self.quant_key.static,
+                self.quant_key.group_shape):
             self._register(pm_pass)

     def _register(self, pm_pass: PatternMatcherPass):
@@ -60,19 +68,15 @@ class AttentionStaticQuantPattern:
         def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     output_attn: torch.Tensor, output_quant: torch.Tensor,
                     scale: torch.Tensor):
-            view_7 = RESHAPE_OP(output_attn,
-                                [-1, self.num_heads, self.head_size])
-
             at1 = auto_functionalized(ATTN_OP,
                                       query=q,
                                       key=k,
                                       value=v,
-                                      output=view_7,
+                                      output=output_attn,
                                       layer_name=self.layer_name,
                                       output_scale=None)
             attn_out_view = RESHAPE_OP(at1[1],
                                        [-1, self.num_heads * self.head_size])
             at2 = auto_functionalized(self.QUANT_OP,
                                       result=output_quant,
                                       input=attn_out_view,
@@ -82,17 +86,19 @@ class AttentionStaticQuantPattern:
         def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         output_attn: torch.Tensor, output_quant: torch.Tensor,
                         scale: torch.Tensor):
-            view_7 = RESHAPE_OP(output_quant,
-                                [-1, self.num_heads, self.head_size])
+            # attn output in quant_dtype
+            output_attn = torch.ops.aten.full.default(
+                [q.shape[0], self.num_heads, self.head_size],
+                0.0,
+                dtype=self.quant_dtype,
+                device=q.device)
             at1 = auto_functionalized(ATTN_OP,
                                       query=q,
                                       key=k,
                                       value=v,
-                                      output=view_7,
+                                      output=output_attn,
                                       layer_name=self.layer_name,
                                       output_scale=scale)

             return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

         # Need custom fake mode, otherwise tracing happens with real tensors.
@@ -102,7 +108,7 @@ class AttentionStaticQuantPattern:
             empty_bf16(5, self.num_heads, self.head_size),  # q
             empty_bf16(5, self.num_heads, self.head_size),  # k
             empty_bf16(5, self.num_heads, self.head_size),  # v
-            empty_bf16(5, self.num_heads * self.head_size),  # attn_output
+            empty_bf16(5, self.num_heads, self.head_size),  # attn_output
             self.empty_quant(5, self.num_heads *
                              self.head_size),  # quant_output
             empty_fp32(1, 1)  # scale
@@ -140,27 +146,30 @@ class AttnFusionPass(VllmInductorPass):

     def __init__(self, config: VllmConfig):
         super().__init__(config)
-        self.static_fwd_ctx = config.compilation_config.static_forward_context

         self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

-        for key, layer in self.static_fwd_ctx.items():
-            pattern = AttentionStaticQuantPattern(key, layer.num_heads,
-                                                  layer.head_size,
-                                                  current_platform.fp8_dtype())
-            pattern.register_if_supported(self.patterns, layer)
-        if len(self.static_fwd_ctx) == 0:
+        attn_layers = get_layers_from_vllm_config(config, Attention)
+        for layer_name, layer in attn_layers.items():
+            pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE)
+            pattern.register_if_supported(self.patterns)
+        if len(attn_layers) == 0:
             logger.warning(
-                "Attention + quant fusion is enabled, but "
-                "CompilationConfig.static_forward_context is empty. "
-                "Cannot access attention layers so no fusion "
-                "patterns were registered.")
+                "Attention + quant fusion is enabled, but no attention layers "
+                "were found in CompilationConfig.static_forward_context "
+                "so no fusion patterns were registered.")

     def __call__(self, graph: torch.fx.graph.Graph) -> None:
         self.begin()
         self.dump_graph(graph, "before_attn_fusion")

         count = self.patterns.apply(graph)
+        # TODO: Move this to pass_manager.py after the fx graph broken issue
+        # has been resolved.
+        # see https://github.com/vllm-project/vllm/issues/23091
+        graph.eliminate_dead_code()
+
         logger.debug("Fused quantization onto %s attention nodes", count)
         self.dump_graph(graph, "after_attn_fusion")
         self.end_and_log()
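To make the intent of the pattern/replacement pair above concrete, here is a small, self-contained sketch of the rewrite this pass performs. The `attention` and `static_fp8_quant` helpers below are illustrative stand-ins, not vLLM's custom ops; only the shape of the rewrite (quant node removed, scale handed to the attention op) comes from the code above.

# Minimal sketch of the AttentionStaticQuantPattern rewrite.
import torch
import torch.nn.functional as F

FP8 = torch.float8_e4m3fn

def static_fp8_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor static quant: divide by scale, clamp to FP8 range, cast.
    finfo = torch.finfo(FP8)
    return (x / scale).clamp(finfo.min, finfo.max).to(FP8)

def attention(q, k, v, output_scale=None):
    out = F.scaled_dot_product_attention(q, k, v)
    # With fusion, the kernel itself applies the output scale and emits FP8.
    return out if output_scale is None else static_fp8_quant(out, output_scale)

# Before fusion: attention writes the model dtype, then a separate static
# quant kernel converts to FP8 (extra kernel launch plus an extra copy).
def unfused(q, k, v, scale):
    return static_fp8_quant(attention(q, k, v), scale)

# After fusion: the quant node is gone; its scale rides on the attention op.
def fused(q, k, v, scale):
    return attention(q, k, v, output_scale=scale)

q = k = v = torch.randn(1, 8, 16, 64)
scale = torch.tensor(0.05)
assert torch.equal(unfused(q, k, v, scale).to(torch.float32),
                   fused(q, k, v, scale).to(torch.float32))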

View File: vllm/utils/flashinfer.py

@@ -174,21 +174,30 @@ def supports_trtllm_attention() -> tuple[bool, Optional[str]]:

 def use_trtllm_attention(
+    num_qo_heads: int,
+    num_kv_heads: int,
     num_tokens: int,
     max_seq_len: int,
     kv_cache_dtype: str,
-    num_qo_heads: Optional[int],
-    num_kv_heads: Optional[int],
-    attn_head_size: Optional[int],
+    q_dtype: torch.dtype,
+    is_prefill: bool,
     has_sinks: bool = False,
 ) -> bool:
     use_trtllm, env_value = supports_trtllm_attention()
     if not use_trtllm:
         return False

-    # Check if the dimensions are supported by TRTLLM decode attention
-    if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
-            or num_qo_heads % num_kv_heads != 0):
+    if num_qo_heads % num_kv_heads != 0:
+        return False
+
+    # Must use TRTLLM attention if query is FP8 quantized
+    if q_dtype == current_platform.fp8_dtype():
+        logger.info_once("Using TRTLLM attention (query is quantized).")
+        return True
+
+    # TRTLLM prefill attention does not support FP8 kv cache with
+    # non-quantized query
+    if is_prefill and kv_cache_dtype.startswith("fp8"):
         return False

     # If sinks are being used, we must use TRTLLM attention as it's
@@ -290,6 +299,7 @@ __all__ = [
     "has_flashinfer_moe",
     "has_flashinfer_cutlass_fused_moe",
     "has_nvidia_artifactory",
+    "supports_trtllm_attention",
     "use_trtllm_attention",
     "flashinfer_scaled_fp4_mm",
 ]
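Assuming a GPU on which `supports_trtllm_attention()` succeeds and divisible head counts, the reworked helper would be called roughly as follows; the head counts, token counts, and dtype combinations are invented for illustration.

# Illustrative call sites for the new use_trtllm_attention() signature.
import torch
from vllm.utils.flashinfer import use_trtllm_attention

num_qo_heads, num_kv_heads = 64, 8

# FP8 query (attn+quant fusion path): TRT-LLM attention is mandatory, so this
# returns True as soon as the platform and head-divisibility checks pass.
use_trtllm_attention(num_qo_heads, num_kv_heads, num_tokens=256,
                     max_seq_len=8192, kv_cache_dtype="fp8",
                     q_dtype=torch.float8_e4m3fn, is_prefill=False)

# bf16 query over an FP8 KV cache during prefill: unsupported, returns False.
use_trtllm_attention(num_qo_heads, num_kv_heads, num_tokens=4096,
                     max_seq_len=8192, kv_cache_dtype="fp8_e4m3",
                     q_dtype=torch.bfloat16, is_prefill=True)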

View File: vllm/v1/attention/backends/flashinfer.py

@@ -15,12 +15,17 @@ from flashinfer.decode import (_get_range_buf, get_seq_lens,
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache

 import vllm.envs as envs
+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import use_trtllm_attention
+from vllm.utils.flashinfer import (supports_trtllm_attention,
+                                   use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention

 # yapf conflicts with isort for this block
 # yapf: disable
@@ -35,6 +40,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec

 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024

+FP8_DTYPE = current_platform.fp8_dtype()
+
 logger = init_logger(__name__)

@@ -519,22 +526,27 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         else:
             kv_cache_dtype = self.kv_cache_spec.dtype

-        num_qo_heads = self.vllm_config.model_config.get_num_attention_heads(
-            self.vllm_config.parallel_config)
+        config = self.vllm_config
+        num_qo_heads = config.model_config.get_num_attention_heads(
+            config.parallel_config)
         num_kv_heads = self.kv_cache_spec.num_kv_heads
         head_dim = self.kv_cache_spec.head_size

         # Check if any layer uses sinks (requires TRTLLM attention)
         has_sinks = self.global_hyperparameters.has_sinks

-        # currently prefill trtllm attention does not support fp8 kv cache
-        prefill_use_trtllm = not cache_dtype.startswith("fp8") \
-            and use_trtllm_attention(
-                num_prefill_tokens, max_seq_len, cache_dtype,
-                num_qo_heads, num_kv_heads, head_dim, has_sinks)
+        # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
+        q_dtype = config.model_config.dtype
+        enable_fusion = config.compilation_config.pass_config.enable_attn_fusion
+        if cache_dtype.startswith("fp8") and enable_fusion:
+            q_dtype = kv_cache_dtype
+
+        prefill_use_trtllm = use_trtllm_attention(
+            num_qo_heads, num_kv_heads, num_prefill_tokens, max_seq_len,
+            cache_dtype, q_dtype, is_prefill=True, has_sinks=has_sinks)
         decode_use_trtllm = use_trtllm_attention(
-            num_decode_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim, has_sinks)
+            num_qo_heads, num_kv_heads, num_decode_tokens, max_seq_len,
+            cache_dtype, q_dtype, is_prefill=False, has_sinks=has_sinks)

         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -548,7 +560,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             head_dim=head_dim,
             page_size=page_size,
             kv_data_type=kv_cache_dtype,
-            q_data_type=self.vllm_config.model_config.dtype,
+            q_data_type=q_dtype,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
             max_seq_len=max_seq_len,
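The query-dtype selection above boils down to a small rule: the query stays in the model dtype unless the KV cache is FP8 and the attention+quant fusion pass is enabled, in which case the query is expected in the same FP8 dtype. The hypothetical `pick_q_dtype` helper below restates that rule in isolation; it is not part of the commit.

# Standalone restatement of the q_dtype selection in the metadata builder.
import torch

def pick_q_dtype(model_dtype: torch.dtype, kv_cache_torch_dtype: torch.dtype,
                 cache_dtype_str: str, enable_attn_fusion: bool) -> torch.dtype:
    if cache_dtype_str.startswith("fp8") and enable_attn_fusion:
        return kv_cache_torch_dtype      # e.g. torch.float8_e4m3fn
    return model_dtype                   # e.g. torch.bfloat16

assert pick_q_dtype(torch.bfloat16, torch.float8_e4m3fn,
                    "fp8", True) == torch.float8_e4m3fn
assert pick_q_dtype(torch.bfloat16, torch.float8_e4m3fn,
                    "fp8", False) == torch.bfloat16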
@@ -622,6 +634,8 @@ class FlashInferImpl(AttentionImpl):
             self.sliding_window = (-1, -1)
         else:
             self.sliding_window = (sliding_window - 1, 0)
+        self.window_left = (self.sliding_window[0]
+                            if self.sliding_window is not None else -1)
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
@@ -644,6 +658,19 @@ class FlashInferImpl(AttentionImpl):
             )
         self.sinks = sinks

+        self.support_trtllm_attn = (supports_trtllm_attention() and
+                                    num_heads % num_kv_heads == 0)
+        self.bmm1_scale: Optional[float] = None
+        self.bmm2_scale: Optional[float] = None
+
+    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
+                                     group_shape: GroupShape):
+        supported_quant_type = (dtype == FP8_DTYPE and static and
+                                group_shape == GroupShape.PER_TENSOR)
+        return (self.support_trtllm_attn
+                and self.kv_cache_dtype.startswith("fp8")
+                and supported_quant_type)
+
     def forward(
         self,
         layer: torch.nn.Module,
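Restated outside the class for clarity, the support check above reduces to three conditions; `trtllm_ok` below stands in for the platform and head-divisibility checks, and the rest mirrors the logic in the hunk.

# The fusion pass only registers its pattern when this returns True.
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

FP8 = torch.float8_e4m3fn

def fused_output_quant_supported(trtllm_ok: bool, kv_cache_dtype: str,
                                 dtype: torch.dtype, static: bool,
                                 group_shape: GroupShape) -> bool:
    supported_quant_type = (dtype == FP8 and static
                            and group_shape == GroupShape.PER_TENSOR)
    return (trtllm_ok and kv_cache_dtype.startswith("fp8")
            and supported_quant_type)

# Only static, per-tensor FP8 on top of an FP8 KV cache qualifies.
assert fused_output_quant_supported(True, "fp8", FP8, True,
                                    GroupShape.PER_TENSOR)
assert not fused_output_quant_supported(True, "auto", FP8, True,
                                        GroupShape.PER_TENSOR)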
@@ -672,15 +699,42 @@ class FlashInferImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

-        if output_scale is not None:
-            raise NotImplementedError(
-                "fused output quantization is not yet supported"
-                " for FlashInferImpl")
-
         if attn_metadata is None:
             # Profiling run.
             return output

+        if self.bmm1_scale is None:
+            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
+                               self.scale)
+
+        if self.bmm2_scale is None:
+            self.bmm2_scale = layer._v_scale_float
+
+        # The attn+quant fusion happens when output_scale is provided.
+        if output_scale is None:
+            assert attn_metadata.q_data_type != FP8_DTYPE, \
+                "Query can only be FP8 if output fusion happened."
+        else:
+            assert attn_metadata.q_data_type == FP8_DTYPE, \
+                "Query must be FP8 when attn+quant fusion happened."
+            assert (attn_metadata.prefill_use_trtllm and
+                    attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
+            assert output.dtype == FP8_DTYPE, \
+                "Output must be FP8 when attn+quant fusion happened."
+
+            # TRTLLM attn kernel requires o scale as a host scalar, store the
+            # o scale to host scalar in warmup run with cuda graph not enabled
+            if layer._o_scale_float is None:
+                layer._o_scale_float = output_scale.cpu().item()
+                self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
+
+            # Insert FP8 quant for query
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape((num_tokens,
+                               num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
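A worked example of the scale folding performed above, with invented per-tensor scales: `bmm1_scale` dequantizes Q·K^T before softmax, and `bmm2_scale` dequantizes P·V, additionally folding in the output requant scale when the FP8 output fusion is active.

# Scale folding for the TRT-LLM FP8 attention kernel (values are made up).
head_size = 128
sm_scale = head_size ** -0.5                    # 1/sqrt(d), i.e. self.scale
q_scale, k_scale, v_scale = 0.02, 0.03, 0.05    # per-tensor quant scales
o_scale = 0.04                                  # static output quant scale

bmm1_scale = q_scale * k_scale * sm_scale       # applied to Q.K^T logits
bmm2_scale = v_scale                            # bf16 output: dequant of P.V
bmm2_scale_fused = v_scale / o_scale            # FP8 output: dequant + requant

print(bmm1_scale, bmm2_scale, bmm2_scale_fused)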
@@ -718,9 +772,6 @@ class FlashInferImpl(AttentionImpl):
                 self.kv_cache_dtype)
             kv_cache = kv_cache.view(torch_dtype)

-        window_left = (self.sliding_window[0]
-                       if self.sliding_window is not None else -1)
-
         # Inputs and outputs may be padded for CUDA graphs
         query = query[:num_actual_tokens]
         output_padded = output
@@ -748,7 +799,7 @@ class FlashInferImpl(AttentionImpl):
             if not attn_metadata.prefill_use_trtllm:
                 assert prefill_wrapper._causal
-                assert prefill_wrapper._window_left == window_left
+                assert prefill_wrapper._window_left == self.window_left
                 assert prefill_wrapper._logits_soft_cap == (
                     self.logits_soft_cap or 0.0)
                 assert prefill_wrapper._sm_scale == self.scale
@@ -783,12 +834,12 @@ class FlashInferImpl(AttentionImpl):
                     seq_lens=seq_lens_prefill,
                     max_q_len=attn_metadata.max_q_len,
                     max_kv_len=attn_metadata.max_seq_len,
-                    bmm1_scale=layer._k_scale_float * self.scale,
-                    bmm2_scale=layer._v_scale_float,
+                    bmm1_scale=self.bmm1_scale,
+                    bmm2_scale=self.bmm2_scale,
                     batch_size=attn_metadata.num_prefills,
                     cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                     cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
-                    window_left=window_left,
+                    window_left=self.window_left,
                     sinks=self.sinks,
                     out=output[num_decode_tokens:],
                 )
@@ -800,7 +851,7 @@ class FlashInferImpl(AttentionImpl):
             assert decode_wrapper is not None

             if not attn_metadata.decode_use_trtllm:
-                assert decode_wrapper._window_left == window_left
+                assert decode_wrapper._window_left == self.window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                            or 0.0)
                 assert decode_wrapper._sm_scale == self.scale
@@ -815,8 +866,8 @@ class FlashInferImpl(AttentionImpl):
                 # decode_query may be non-contiguous
                 decode_query = decode_query.contiguous()
                 workspace_buffer = decode_wrapper._float_workspace_buffer
-                block_tables_decode = attn_metadata.block_table_tensor[:
-                                                                       num_decode_tokens]
+                block_tables_decode = attn_metadata.\
+                    block_table_tensor[:num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
@@ -834,9 +885,9 @@ class FlashInferImpl(AttentionImpl):
                     block_tables=block_tables_decode,
                     seq_lens=seq_lens_decode,
                     max_seq_len=attn_metadata.max_seq_len,
-                    bmm1_scale=layer._k_scale_float * self.scale,
-                    bmm2_scale=layer._v_scale_float,
-                    window_left=window_left,
+                    bmm1_scale=self.bmm1_scale,
+                    bmm2_scale=self.bmm2_scale,
+                    window_left=self.window_left,
                     sinks=self.sinks,
                     out=output[:num_decode_tokens],
                 )
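For context, a hedged sketch of an engine configuration that would exercise this fused path. Only `enable_attn_fusion`, `kv_cache_dtype="fp8"`, and the HND KV-cache layout requirement come from the code above; the model name and the exact shape of the `compilation_config` argument are assumptions about the surrounding vLLM API and may differ by version.

# Hypothetical end-to-end setup for the TRT-LLM FP8 q/kv/out path.
import os

os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"          # required by the decode path
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # use the FlashInfer backend

from vllm import LLM

llm = LLM(
    model="nvidia/Llama-3.1-8B-Instruct-FP8",        # placeholder FP8 checkpoint
    kv_cache_dtype="fp8",                            # FP8 KV cache
    compilation_config={
        "pass_config": {"enable_attn_fusion": True},  # attention+quant fusion
    },
)
print(llm.generate("Hello"))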