diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 0912bc1fd94f..d4fcb91b11b0 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -631,6 +631,7 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/v1/attention/backends/flashinfer.py - vllm/compilation/fusion.py + - vllm/compilation/fusion_attn.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -647,6 +648,7 @@ steps: - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern ##### 1 GPU test ##### ##### multi gpus test ##### diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 77136edca45b..b3f81715461b 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -3,16 +3,14 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FP8_DTYPE = torch.float8_e4m3fn def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,65 +24,107 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_decode( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") - device = "cuda" torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + NUM_BLOCKS = int(256000 / block_size) - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") - # For decode, batch_size is num_decode_token - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) - q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) - kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query - max_kv_len = max(kv_lens) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) - 
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size + kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_seq_len + seq_lens = kv_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 ) + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(batch_size): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) - k_scale = v_scale = 1.0 + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) - - output_trtllm = torch.empty(q.shape, dtype=dtype) - - # Benchmark TRT decode - def trt_decode(): - return flashinfer.decode.trtllm_batch_decode_with_kv_cache( - q, - kv_cache, - workspace_buffer, - block_tables, - kv_lens_tensor, - max_kv_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + ) def time_fn(fn, warmup=10, trials=20): torch.cuda.synchronize() @@ -101,74 +141,51 @@ def benchmark_decode( times.append(start.elapsed_time(end)) # ms return sum(times) / len(times), torch.std(torch.tensor(times)) - # TRT Decode - trt_mean, trt_std = time_fn(trt_decode) - - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size - if kv_last_page_len == 0: - kv_last_page_len = page_size - kv_last_page_lens.append(kv_last_page_len) - - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) - - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout, - use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), - ) - - wrapper.plan( - kv_indptr, - kv_indices, - kv_last_page_lens, - 
num_qo_heads, - num_kv_heads, - head_dim, - page_size, - "NONE", - q_data_type=dtype, - kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype, - ) + o_scale = 1.0 + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) def baseline_decode(): - return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline) + return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) + + def trtllm_decode(): + return flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + out=output_trtllm, + ) baseline_mean, baseline_std = time_fn(baseline_decode) + trtllm_mean, trtllm_std = time_fn(trtllm_decode) # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}" f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -180,17 +197,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -209,45 +227,42 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + ] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in 
max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="fp8", - ) - all_results.append(result) + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_decode( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 67bd9aebbcca..49810e20c7d8 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -3,16 +3,14 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FP8_DTYPE = torch.float8_e4m3fn def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,84 +24,99 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_prefill( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + max_q_len = max_kv_len = max_seq_len + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + NUM_BLOCKS = int(256000 / block_size) - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8) + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) - - q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - q_lens[-1] = MAX_SEQ_LEN - max_q_len = max(q_lens) + q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32) + q_lens[-1] = max_q_len q_indptr = torch.cat( [ torch.tensor([0], dtype=torch.int32), - torch.cumsum( - torch.tensor(q_lens, dtype=torch.int32), 
dim=0, dtype=torch.int32 - ), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), ] ) - q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype) - kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)] - kv_lens[-1] = MAX_SEQ_LEN + query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query - seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)] - max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) + kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_kv_len - max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 ) - - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype) - k_scale = v_scale = 1.0 - - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) - - output_trtllm = torch.empty(q.shape, dtype=dtype) - kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size + num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_tables[i, :num_blocks]) kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size + kv_last_page_len = seq_len % block_size if kv_last_page_len == 0: - kv_last_page_len = page_size + kv_last_page_len = block_size kv_last_page_lens.append(kv_last_page_len) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout @@ -115,12 +128,12 @@ def benchmark_prefill( kv_last_page_lens, num_qo_heads, num_kv_heads, - head_dim, - page_size, + head_size, + block_size, causal=True, sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache.dtype, + kv_data_type=dtype, ) def time_fn(fn, warmup=10, trials=20): @@ -138,52 +151,55 @@ def benchmark_prefill( times.append(start.elapsed_time(end)) # ms return sum(times) / len(times), torch.std(torch.tensor(times)) - def baseline_prefill(): - return wrapper.run( - q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline - ) + o_scale = 1.0 + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) - def trt_prefill(): + def baseline_prefill(): + return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) + + def trtllm_prefill(): return 
flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=q, + query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, - seq_lens=seq_lens_tensor, + seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, out=output_trtllm, ) - trt_mean, trt_std = time_fn(trt_prefill) baseline_mean, baseline_std = time_fn(baseline_prefill) + trtllm_mean, trtllm_std = time_fn(trtllm_prefill) # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}" - f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}" + f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -195,17 +211,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -224,27 +241,41 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_prefill( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + ] + + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + 
"baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_prefill( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 70750eb9ac4e..bef0fdef985e 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy from typing import Optional import pytest @@ -7,13 +8,27 @@ import torch._dynamo from tests.compile.backend import TestBackend from tests.models.utils import check_outputs_equal +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata) from vllm import LLM, SamplingParams +from vllm.attention import Attention +from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + ModelConfig, PassConfig, SchedulerConfig, VllmConfig, + set_current_vllm_config) +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp) from vllm.platforms import current_platform +from vllm.v1.kv_cache_interface import AttentionSpec + +FP8_DTYPE = current_platform.fp8_dtype() # globals needed for string-import custom Dynamo backend field backend: Optional[TestBackend] = None @@ -132,3 +147,235 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # Reset backend to make sure llm2 gets released backend = None + + +class TestAttentionStaticQuantPatternModel(torch.nn.Module): + """Test model for AttentionStaticQuantPattern fusion.""" + + def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, + kv_cache_dtype: torch.dtype, device: torch.device, + vllm_config: VllmConfig): + super().__init__() + self.num_qo_heads = num_qo_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.kv_cache_dtype = kv_cache_dtype + self.device = device + self.vllm_config = vllm_config + + self.attn = Attention( + num_heads=self.num_qo_heads, + head_size=self.head_size, + scale=1.0 / (self.head_size**0.5), + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + prefix="model.layers.0.self_attn.attn", + ) + + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) + self.wscale = torch.tensor([1.0], dtype=torch.float32) + self.scale = torch.tensor([1.0], dtype=torch.float32) + + self.block_size = 16 + + # Initialize attn MetadataBuilder + self.builder = self.attn.attn_backend.get_builder_cls()( + kv_cache_spec=AttentionSpec( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_dtype, + use_mla=False, + ), + layer_names=[self.attn.layer_name], + vllm_config=self.vllm_config, + device=self.device, 
+ ) + + def build_attn_metadata(self, batch_size: int): + """Initialize attention metadata.""" + + # Create common attn metadata + batch_spec = BatchSpec(seq_lens=[1] * batch_size, + query_lens=[1] * batch_size) + common_attn_metadata = create_common_attn_metadata( + batch_spec, + self.block_size, + self.device, + arange_block_indices=True) + + max_blocks = (max(batch_spec.seq_lens) + self.block_size - + 1) // self.block_size + num_blocks = batch_size * max_blocks + + # Create dummy KV cache for FlashInfer TRTLLM + # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + # Create kv_cache in HND layout and permute to NHD layout + # (later will be permuted back to HND layout in forward pass) + kv_cache = torch.zeros(num_blocks, + 2, + self.num_kv_heads, + self.block_size, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device) + kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + self.attn.kv_cache = [kv_cache] + + # Build attn metadata + self.attn_metadata = self.builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata) + + return self.attn_metadata + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + w: torch.Tensor): + """Forward pass that creates the pattern to be fused.""" + attn_output = self.attn(q, k, v) + return self.fp8_linear.apply(input=attn_output, + weight=w, + weight_scale=self.wscale, + input_scale=self.scale) + + +@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("batch_size", [7, 256, 533]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize( + "model_name, quant_key", + [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)]) +@pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") +@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") +@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), + reason="Only test on SM100(Blackwell)") +def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, + head_size: int, batch_size: int, + dtype: torch.dtype, model_name: str, + quant_key: QuantKey, backend: _Backend, + monkeypatch, dist_init): + """Test AttentionStaticQuantPattern fusion pass""" + + monkeypatch.setenv("VLLM_USE_V1", "1") + + device = torch.device("cuda:0") + torch.manual_seed(42) + + vllm_config = VllmConfig( + model_config=ModelConfig( + model=model_name, + max_model_len=2048, + ), + scheduler_config=SchedulerConfig(max_num_seqs=1024), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+quant_fp8"], + ), + cache_config=CacheConfig(cache_dtype="fp8")) + + # Create test inputs + hidden_size = num_qo_heads * head_size + q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + k = torch.randn(batch_size, + num_kv_heads * head_size, + dtype=dtype, + device=device) + v = torch.randn(batch_size, + num_kv_heads * head_size, + dtype=dtype, + device=device) + linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t() + + # Mark first dimension as dynamic for realistic testing + torch._dynamo.mark_dynamic(q, 0) + torch._dynamo.mark_dynamic(k, 0) + torch._dynamo.mark_dynamic(v, 0) + + # Run model directly without compilation and fusion + vllm_config_unfused = copy.deepcopy(vllm_config) + with 
set_current_vllm_config(vllm_config_unfused), set_forward_context( + attn_metadata=None, vllm_config=vllm_config_unfused + ), global_force_attn_backend_context_manager(backend): + model_unfused = TestAttentionStaticQuantPatternModel( + num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, + vllm_config_unfused) + model_unfused = model_unfused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_unfused.build_attn_metadata( + batch_size) + + # Run model directly without compilation and fusion + result_unfused = model_unfused(q, k, v, linear_w) + + # Run model with attn fusion enabled + vllm_config.compilation_config.pass_config = PassConfig( + enable_attn_fusion=True, enable_noop=True) + with set_current_vllm_config(vllm_config), set_forward_context( + attn_metadata=None, vllm_config=vllm_config + ), global_force_attn_backend_context_manager(backend): + model_fused = TestAttentionStaticQuantPatternModel( + num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, + vllm_config) + model_fused = model_fused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) + + # Create test backend with fusion passes enabled + noop_pass = NoOpEliminationPass(vllm_config) + attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw + ) + test_backend = TestBackend(noop_pass, attn_pass) + + # Compile model with fusion enabled + model_compiled = torch.compile(model_fused, + backend=test_backend, + fullgraph=True) + assert model_compiled.attn._o_scale_float is None + result_fused_1 = model_compiled(q, k, v, linear_w) + + # After the 1st round of the forward pass, output quant scale should be + # loaded into the attn layer's _o_scale_float, the 2nd round should + # reuse the loaded _o_scale_float + assert model_compiled.attn._o_scale_float is not None + result_fused_2 = model_compiled(q, k, v, linear_w) + assert model_compiled.attn._o_scale_float is not None + + # Check attn fusion support + attn_fusion_supported = [ + layer.impl.fused_output_quant_supported(quant_key.dtype, + quant_key.static, + quant_key.group_shape) for key, + layer in vllm_config.compilation_config.static_forward_context.items() + ] + if any(attn_fusion_supported): + # Check quantization ops in the graph before and after fusion + test_backend.check_before_ops([QUANT_OPS[quant_key]], + fully_replaced=True) + + # Check attention ops in the graph before and after fusion + attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) + attn_nodes_post = list(find_op_nodes(ATTN_OP, + test_backend.graph_post_pass)) + + assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion" + assert len(attn_nodes_pre) == len(attn_nodes_post), \ + "Should have same number of attention nodes before and after fusion" + assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \ + "Attention should not have output_scale before fusion" + assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \ + "Attention should have output_scale after fusion" + + # Check that results are closed + torch.testing.assert_close(result_unfused, + result_fused_1, + atol=1e-2, + rtol=1e-2) + torch.testing.assert_close(result_unfused, + result_fused_2, + atol=1e-2, + rtol=1e-2) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 4b84e6a00ece..619822f3ee43 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ 
b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -13,21 +13,7 @@ if not current_platform.is_device_capability(100): allow_module_level=True) FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) - -MAX_Q_LEN = 1024 -MAX_KV_LEN = 4096 -BATCH_SIZES = [4, 12] -NUM_HEADS = [(16, 16), (40, 8)] -HEAD_SIZES = [128] -BLOCK_SIZES = [16] -KV_LAYOUTS = ["HND"] -DTYPES = [torch.bfloat16] -KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()] -NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. -SOFT_CAPS = [None, 50.0] +FP8_DTYPE = current_platform.fp8_dtype() def to_float8(x, dtype=torch.float8_e4m3fn): @@ -39,42 +25,59 @@ def to_float8(x, dtype=torch.float8_e4m3fn): return x_scl_sat.to(dtype), scale.float().reciprocal() -@pytest.mark.parametrize("batch_size", BATCH_SIZES) +DTYPE = [torch.bfloat16] +QUANT_DTYPES = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), +] +BATCH_SIZE = [4, 12] +MAX_SEQ_LENS = [(1024, 4096)] +NUM_HEADS = [(64, 8), (40, 8)] +HEAD_SIZE = [128] +KV_LAYOUT = ["HND"] # currently only HND is supported +BLOCK_SIZE = [16] +SOFT_CAP = [None, 50.0] + +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. + + +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) -@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.parametrize("head_size", HEAD_SIZE) +@pytest.mark.parametrize("kv_layout", KV_LAYOUT) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("soft_cap", SOFT_CAP) @torch.inference_mode def test_flashinfer_trtllm_decode_with_baseline( + dtype: torch.dtype, + quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], + Optional[torch.dtype]], batch_size: int, + max_seq_lens: tuple[int, int], num_heads: tuple[int, int], head_size: int, - block_size: int, kv_layout: str, - dtype: torch.dtype, - kv_cache_dtype: Optional[torch.dtype], + block_size: int, soft_cap: Optional[float], ) -> None: - kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - torch.set_default_device("cuda") current_platform.seed_everything(0) - kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) - kv_lens[-1] = MAX_KV_LEN - max_kv_len = torch.max(kv_lens).item() - num_seqs = len(kv_lens) + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 + _, max_kv_len = max_seq_lens - scale = head_size**-0.5 + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + sm_scale = float(1.0 / (head_size**0.5)) kv_cache_shape = None if kv_layout == "NHD": @@ -83,156 +86,39 @@ def test_flashinfer_trtllm_decode_with_baseline( kv_cache_shape = 
(NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - kv_scale = 1.0 - if kv_cache_dtype is current_platform.fp8_dtype(): - key_value_cache, kv_scale = to_float8(key_value_cache, - current_platform.fp8_dtype()) - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - k_scale = v_scale = kv_scale - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + block_size - 1) // block_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % block_size - if kv_last_page_len == 0: - kv_last_page_len = block_size - kv_last_page_lens.append(kv_last_page_len) + query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = max_kv_len - workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout, - use_tensor_cores=((num_query_heads // num_kv_heads) > 4)) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - sm_scale=scale, - q_data_type=dtype, - kv_data_type=kv_cache_dtype, - logits_soft_cap=soft_cap) - - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, - key_value_cache, - k_scale=k_scale, - v_scale=v_scale, - out=output) - - # TRTLLM Decode - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - output_trtllm = torch.empty(query.shape, dtype=dtype) - flashinfer.decode.trtllm_batch_decode_with_kv_cache( - query=query.contiguous(), - kv_cache=key_value_cache, - workspace_buffer=workspace_buffer, - block_tables=block_tables, - seq_lens=kv_lens_tensor, - max_seq_len=max_kv_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) - - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - output_trtllm))}" - - -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) -@pytest.mark.parametrize("soft_cap", [None]) -@torch.inference_mode -def test_flashinfer_trtllm_prefill_with_baseline( - batch_size: int, - num_heads: tuple[int, int], - head_size: int, - block_size: int, - kv_layout: str, - dtype: torch.dtype, - kv_cache_dtype: Optional[torch.dtype], - soft_cap: Optional[float], -) -> None: - kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - if dtype != kv_cache_dtype: - pytest.skip(f"Not supported dtype({dtype}) with " - "kv_cache_dtype({kv_cache_dtype})") - - 
torch.set_default_device("cuda") - current_platform.seed_everything(0) - - q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32) - q_lens[-1] = MAX_Q_LEN - max_q_len = torch.max(q_lens).item() - q_indptr = torch.cat([ - torch.tensor([0], dtype=torch.int32), - torch.cumsum(q_lens, dim=0, dtype=torch.int32), - ]) - - kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) - kv_lens[-1] = MAX_KV_LEN - - seq_lens = kv_lens + q_lens + seq_lens = kv_lens max_seq_len = torch.max(seq_lens).item() - num_seqs = len(seq_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - - scale = head_size**-0.5 - - query = torch.randn(torch.sum(q_lens).item(), - num_query_heads, - head_size, - dtype=dtype) - - kv_cache_shape = None - if kv_layout == "NHD": - kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) - elif kv_layout == "HND": - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale else: - raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - kv_scale = 1.0 - if kv_cache_dtype is current_platform.fp8_dtype(): - key_value_cache, kv_scale = to_float8(key_value_cache, - current_platform.fp8_dtype()) + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint(0, NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), + (batch_size, max_num_blocks_per_seq), dtype=torch.int32) - k_scale = v_scale = kv_scale kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 num_blocks = (seq_len + block_size - 1) // block_size @@ -246,48 +132,206 @@ def test_flashinfer_trtllm_prefill_with_baseline( kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) + + # Baseline Decode + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=((num_qo_heads // num_kv_heads) > 4)) + wrapper.plan(kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + logits_soft_cap=soft_cap) + + output = torch.empty(ref_query.shape, dtype=dtype) + wrapper.run(ref_query, ref_kv_cache, out=output) + o_scale = 1.0 + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output) + + # TRTLLM Decode + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + out=output_trtllm, + ) + if o_quant_dtype == FP8_DTYPE: + output_trtllm = output_trtllm.to(dtype) * o_scale + + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 1e-2 + + torch.testing.assert_close(output, output_trtllm, 
atol=atol, rtol=rtol), \ + f"{torch.max(torch.abs(output - output_trtllm))}" + + +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZE) +@pytest.mark.parametrize("kv_layout", KV_LAYOUT) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("soft_cap", [None]) +@torch.inference_mode +def test_flashinfer_trtllm_prefill_with_baseline( + dtype: torch.dtype, + quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], + Optional[torch.dtype]], + batch_size: int, + max_seq_lens: tuple[int, int], + num_heads: tuple[int, int], + head_size: int, + kv_layout: str, + block_size: int, + soft_cap: Optional[float], +) -> None: + torch.set_default_device("cuda") + current_platform.seed_everything(0) + + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + max_q_len, max_kv_len = max_seq_lens + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + + q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32) + q_lens[-1] = max_q_len + q_indptr = torch.cat([ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ]) + + query = torch.randn(torch.sum(q_lens).item(), + num_qo_heads, + head_size, + dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query + + kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (batch_size, max_num_blocks_per_seq), + dtype=torch.int32) + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(batch_size): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) + + # Baseline Prefill wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout) wrapper.plan(q_indptr, kv_indptr, 
kv_indices, kv_last_page_lens, - num_query_heads, + num_qo_heads, num_kv_heads, head_size, block_size, causal=True, - sm_scale=scale, + sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache_dtype, + kv_data_type=dtype, logits_soft_cap=soft_cap) - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, - key_value_cache, - k_scale=k_scale, - v_scale=v_scale, - out=output) + output = torch.empty(ref_query.shape, dtype=dtype) + wrapper.run(ref_query, ref_kv_cache, out=output) + o_scale = 1.0 + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output) - # TRTLLM Decode - output_trtllm = torch.empty(query.shape, dtype=dtype) + # TRTLLM Prefill + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=query.contiguous(), - kv_cache=key_value_cache, + query=query, + kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, out=output_trtllm, ) + if o_quant_dtype == FP8_DTYPE: + output_trtllm = output_trtllm.to(dtype) * o_scale - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 1e-2 + + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - output_trtllm))}" diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 0e87fa3f23e3..04ab100c8775 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -128,11 +128,17 @@ class Attention(nn.Module): self._q_scale = torch.tensor(1.0, dtype=torch.float32) self._prob_scale = torch.tensor(1.0, dtype=torch.float32) - # We also keep the float32 versions of k/v_scale for attention - # backends that don't support tensors (Flashinfer) + # We also keep q/k/v_scale on host (cpu) memory for attention + # backends that require the scales to be on host instead of on device. + # e.g. Flashinfer + self._q_scale_float = 1.0 self._k_scale_float = 1.0 self._v_scale_float = 1.0 + # The output scale on host memory. This should be the input scale of + # the quant op after this attention layer. 
+ self._o_scale_float: Optional[float] = None + self.use_mla = use_mla self.num_heads = num_heads self.head_size = head_size @@ -291,6 +297,7 @@ class Attention(nn.Module): self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._q_scale_float = self._q_scale.item() self._k_scale_float = self._k_scale.item() self._v_scale_float = self._v_scale.item() # We only calculate the scales once diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index a40a8caf34a8..1f77a2667613 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -9,7 +9,7 @@ from torch._subclasses.fake_tensor import (FakeTensorMode, unset_fake_temporarily) from vllm.attention import Attention -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform @@ -18,23 +18,32 @@ from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() + ATTN_OP = torch.ops.vllm.unified_attention_with_output.default RESHAPE_OP = torch.ops.aten.reshape.default class AttentionStaticQuantPattern: + """ + Fusion for Attention+StaticQuant. + + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the StaticQuant + op will be removed from the graph, and its scale will be passed into + Attention op as the `output_scale` argument. + """ def __init__( self, - layer_name: str, - num_heads: int, - head_size: int, + layer: Attention, quant_dtype: torch.dtype, symmetric=True, ): - self.layer_name = layer_name - self.num_heads = num_heads - self.head_size = head_size + self.layer = layer + self.layer_name = layer.layer_name + self.num_heads = layer.num_heads + self.head_size = layer.head_size self.quant_dtype = quant_dtype self.quant_key = QuantKey(dtype=quant_dtype, static=True, @@ -48,11 +57,10 @@ class AttentionStaticQuantPattern: kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} return torch.empty(*args, **kwargs) - def register_if_supported(self, pm_pass: PatternMatcherPass, - layer: Attention): - if layer.impl.fused_output_quant_supported(self.quant_dtype, - self.quant_key.static, - self.quant_key.group_shape): + def register_if_supported(self, pm_pass: PatternMatcherPass): + if self.layer.impl.fused_output_quant_supported( + self.quant_dtype, self.quant_key.static, + self.quant_key.group_shape): self._register(pm_pass) def _register(self, pm_pass: PatternMatcherPass): @@ -60,19 +68,15 @@ class AttentionStaticQuantPattern: def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, output_quant: torch.Tensor, scale: torch.Tensor): - view_7 = RESHAPE_OP(output_attn, - [-1, self.num_heads, self.head_size]) - at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=view_7, + output=output_attn, layer_name=self.layer_name, output_scale=None) attn_out_view = RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) - at2 = auto_functionalized(self.QUANT_OP, result=output_quant, input=attn_out_view, @@ -82,17 +86,19 @@ class AttentionStaticQuantPattern: def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, output_quant: torch.Tensor, scale: torch.Tensor): - view_7 = RESHAPE_OP(output_quant, - [-1, self.num_heads, self.head_size]) - + # attn 
output in quant_dtype + output_attn = torch.ops.aten.full.default( + [q.shape[0], self.num_heads, self.head_size], + 0.0, + dtype=self.quant_dtype, + device=q.device) at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=view_7, + output=output_attn, layer_name=self.layer_name, output_scale=scale) - return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) # Need custom fake mode, otherwise tracing happens with real tensors. @@ -102,7 +108,7 @@ class AttentionStaticQuantPattern: empty_bf16(5, self.num_heads, self.head_size), # q empty_bf16(5, self.num_heads, self.head_size), # k empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads * self.head_size), # attn_output + empty_bf16(5, self.num_heads, self.head_size), # attn_output self.empty_quant(5, self.num_heads * self.head_size), # quant_output empty_fp32(1, 1) # scale @@ -140,27 +146,30 @@ class AttnFusionPass(VllmInductorPass): def __init__(self, config: VllmConfig): super().__init__(config) - self.static_fwd_ctx = config.compilation_config.static_forward_context self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") - for key, layer in self.static_fwd_ctx.items(): - pattern = AttentionStaticQuantPattern(key, layer.num_heads, - layer.head_size, - current_platform.fp8_dtype()) - pattern.register_if_supported(self.patterns, layer) - if len(self.static_fwd_ctx) == 0: + attn_layers = get_layers_from_vllm_config(config, Attention) + for layer_name, layer in attn_layers.items(): + pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE) + pattern.register_if_supported(self.patterns) + if len(attn_layers) == 0: logger.warning( - "Attention + quant fusion is enabled, but " - "CompilationConfig.static_forward_context is empty. " - "Cannot access attention layers so no fusion " - "patterns were registered.") + "Attention + quant fusion is enabled, but no attention layers " + "were found in CompilationConfig.static_forward_context " + "so no fusion patterns were registered.") def __call__(self, graph: torch.fx.graph.Graph) -> None: self.begin() self.dump_graph(graph, "before_attn_fusion") count = self.patterns.apply(graph) + + # TODO: Move this to pass_manager.py after the fx graph broken issue + # has been resolved. 
+        # see https://github.com/vllm-project/vllm/issues/23091
+        graph.eliminate_dead_code()
+
         logger.debug("Fused quantization onto %s attention nodes", count)
         self.dump_graph(graph, "after_attn_fusion")
         self.end_and_log()
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 2e31b7bad747..996be1265667 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -174,21 +174,30 @@ def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
 
 
 def use_trtllm_attention(
+    num_qo_heads: int,
+    num_kv_heads: int,
     num_tokens: int,
     max_seq_len: int,
     kv_cache_dtype: str,
-    num_qo_heads: Optional[int],
-    num_kv_heads: Optional[int],
-    attn_head_size: Optional[int],
+    q_dtype: torch.dtype,
+    is_prefill: bool,
     has_sinks: bool = False,
 ) -> bool:
     use_trtllm, env_value = supports_trtllm_attention()
     if not use_trtllm:
         return False
 
-    # Check if the dimensions are supported by TRTLLM decode attention
-    if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
-            or num_qo_heads % num_kv_heads != 0):
+    if num_qo_heads % num_kv_heads != 0:
+        return False
+
+    # Must use TRTLLM attention if query is FP8 quantized
+    if q_dtype == current_platform.fp8_dtype():
+        logger.info_once("Using TRTLLM attention (query is quantized).")
+        return True
+
+    # TRTLLM prefill attention does not support FP8 kv cache with
+    # non-quantized query
+    if is_prefill and kv_cache_dtype.startswith("fp8"):
         return False
 
     # If sinks are being used, we must use TRTLLM attention as it's
@@ -290,6 +299,7 @@ __all__ = [
     "has_flashinfer_moe",
     "has_flashinfer_cutlass_fused_moe",
     "has_nvidia_artifactory",
+    "supports_trtllm_attention",
     "use_trtllm_attention",
     "flashinfer_scaled_fp4_mm",
 ]
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 991904229fd7..c56e721dff8c 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -15,12 +15,17 @@ from flashinfer.decode import (_get_range_buf, get_seq_lens,
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 
 import vllm.envs as envs
+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import use_trtllm_attention
+from vllm.utils.flashinfer import (supports_trtllm_attention,
+                                   use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -35,6 +40,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 
+FP8_DTYPE = current_platform.fp8_dtype()
+
 logger = init_logger(__name__)
 
 
@@ -519,22 +526,27 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         else:
             kv_cache_dtype = self.kv_cache_spec.dtype
 
-        num_qo_heads = self.vllm_config.model_config.get_num_attention_heads(
-            self.vllm_config.parallel_config)
+        config = self.vllm_config
+        num_qo_heads = config.model_config.get_num_attention_heads(
+            config.parallel_config)
         num_kv_heads = self.kv_cache_spec.num_kv_heads
         head_dim = self.kv_cache_spec.head_size
 
         # Check if any layer uses sinks (requires TRTLLM attention)
        has_sinks = self.global_hyperparameters.has_sinks
 
-        # currently prefill trtllm attention does not support fp8 kv cache
-        prefill_use_trtllm = not cache_dtype.startswith("fp8") \
-            and use_trtllm_attention(
-                num_prefill_tokens, max_seq_len, cache_dtype,
-                num_qo_heads, num_kv_heads, head_dim, has_sinks)
+        # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
+        q_dtype = config.model_config.dtype
+        enable_fusion = config.compilation_config.pass_config.enable_attn_fusion
+        if cache_dtype.startswith("fp8") and enable_fusion:
+            q_dtype = kv_cache_dtype
+
+        prefill_use_trtllm = use_trtllm_attention(
+            num_qo_heads, num_kv_heads, num_prefill_tokens, max_seq_len,
+            cache_dtype, q_dtype, is_prefill=True, has_sinks=has_sinks)
         decode_use_trtllm = use_trtllm_attention(
-            num_decode_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim, has_sinks)
+            num_qo_heads, num_kv_heads, num_decode_tokens, max_seq_len,
+            cache_dtype, q_dtype, is_prefill=False, has_sinks=has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -548,7 +560,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             head_dim=head_dim,
             page_size=page_size,
             kv_data_type=kv_cache_dtype,
-            q_data_type=self.vllm_config.model_config.dtype,
+            q_data_type=q_dtype,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
             max_seq_len=max_seq_len,
@@ -622,6 +634,8 @@ class FlashInferImpl(AttentionImpl):
             self.sliding_window = (-1, -1)
         else:
             self.sliding_window = (sliding_window - 1, 0)
+        self.window_left = (self.sliding_window[0]
+                            if self.sliding_window is not None else -1)
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
@@ -644,6 +658,19 @@ class FlashInferImpl(AttentionImpl):
             )
 
         self.sinks = sinks
+        self.support_trtllm_attn = (supports_trtllm_attention() and
+                                    num_heads % num_kv_heads == 0)
+        self.bmm1_scale: Optional[float] = None
+        self.bmm2_scale: Optional[float] = None
+
+    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
+                                     group_shape: GroupShape):
+        supported_quant_type = (dtype == FP8_DTYPE and static and
+                                group_shape == GroupShape.PER_TENSOR)
+        return (self.support_trtllm_attn
+                and self.kv_cache_dtype.startswith("fp8")
+                and supported_quant_type)
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -672,15 +699,42 @@ class FlashInferImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."
 
-        if output_scale is not None:
-            raise NotImplementedError(
-                "fused output quantization is not yet supported"
-                " for FlashInferImpl")
-
         if attn_metadata is None:
             # Profiling run.
             return output
 
+        if self.bmm1_scale is None:
+            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
+                               self.scale)
+
+        if self.bmm2_scale is None:
+            self.bmm2_scale = layer._v_scale_float
+
+        # The attn+quant fusion happens when output_scale is provided.
+        if output_scale is None:
+            assert attn_metadata.q_data_type != FP8_DTYPE, \
+                "Query can only be FP8 if output fusion happened."
+        else:
+            assert attn_metadata.q_data_type == FP8_DTYPE, \
+                "Query must be FP8 when attn+quant fusion happened."
+            assert (attn_metadata.prefill_use_trtllm and
+                    attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
+            assert output.dtype == FP8_DTYPE, \
+                "Output must be FP8 when attn+quant fusion happened."
+
+            # TRTLLM attn kernel requires o scale as a host scalar, store the
+            # o scale to host scalar in warmup run with cuda graph not enabled
+            if layer._o_scale_float is None:
+                layer._o_scale_float = output_scale.cpu().item()
+                self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
+
+            # Insert FP8 quant for query
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape((num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -718,9 +772,6 @@ class FlashInferImpl(AttentionImpl):
                 self.kv_cache_dtype)
             kv_cache = kv_cache.view(torch_dtype)
 
-        window_left = (self.sliding_window[0]
-                       if self.sliding_window is not None else -1)
-
         # Inputs and outputs may be padded for CUDA graphs
         query = query[:num_actual_tokens]
         output_padded = output
@@ -748,7 +799,7 @@ class FlashInferImpl(AttentionImpl):
 
             if not attn_metadata.prefill_use_trtllm:
                 assert prefill_wrapper._causal
-                assert prefill_wrapper._window_left == window_left
+                assert prefill_wrapper._window_left == self.window_left
                 assert prefill_wrapper._logits_soft_cap == (
                     self.logits_soft_cap or 0.0)
                 assert prefill_wrapper._sm_scale == self.scale
@@ -783,12 +834,12 @@ class FlashInferImpl(AttentionImpl):
                 seq_lens=seq_lens_prefill,
                 max_q_len=attn_metadata.max_q_len,
                 max_kv_len=attn_metadata.max_seq_len,
-                bmm1_scale=layer._k_scale_float * self.scale,
-                bmm2_scale=layer._v_scale_float,
+                bmm1_scale=self.bmm1_scale,
+                bmm2_scale=self.bmm2_scale,
                 batch_size=attn_metadata.num_prefills,
                 cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                 cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
-                window_left=window_left,
+                window_left=self.window_left,
                 sinks=self.sinks,
                 out=output[num_decode_tokens:],
             )
@@ -800,7 +851,7 @@ class FlashInferImpl(AttentionImpl):
             assert decode_wrapper is not None
 
             if not attn_metadata.decode_use_trtllm:
-                assert decode_wrapper._window_left == window_left
+                assert decode_wrapper._window_left == self.window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                            or 0.0)
                 assert decode_wrapper._sm_scale == self.scale
@@ -815,8 +866,8 @@ class FlashInferImpl(AttentionImpl):
                 # decode_query may be non-contiguous
                 decode_query = decode_query.contiguous()
                 workspace_buffer = decode_wrapper._float_workspace_buffer
-                block_tables_decode = attn_metadata.block_table_tensor[:
-                                                                       num_decode_tokens]
+                block_tables_decode = attn_metadata.\
+                    block_table_tensor[:num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
 
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
@@ -834,9 +885,9 @@ class FlashInferImpl(AttentionImpl):
                 block_tables=block_tables_decode,
                 seq_lens=seq_lens_decode,
                 max_seq_len=attn_metadata.max_seq_len,
-                bmm1_scale=layer._k_scale_float * self.scale,
-                bmm2_scale=layer._v_scale_float,
-                window_left=window_left,
+                bmm1_scale=self.bmm1_scale,
+                bmm2_scale=self.bmm2_scale,
+                window_left=self.window_left,
                 sinks=self.sinks,
                 out=output[:num_decode_tokens],
             )
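
The reworked use_trtllm_attention gating above reads as a small decision function: head counts must divide evenly, an FP8 query (produced by the attn+quant fusion) forces the TRT-LLM kernels on, and prefill with an FP8 KV cache but an unquantized query falls back. The sketch below is a condensed, standalone restatement of only the checks visible in that hunk; the function name wants_trtllm_attention is made up, platform_ok stands in for supports_trtllm_attention(), and the remaining sinks/token-count/env heuristics are elided.

import torch

FP8_DTYPE = torch.float8_e4m3fn


def wants_trtllm_attention(num_qo_heads: int, num_kv_heads: int,
                           kv_cache_dtype: str, q_dtype: torch.dtype,
                           is_prefill: bool, platform_ok: bool) -> bool:
    # Illustrative only: mirrors the checks visible in the hunk above.
    if not platform_ok:
        return False
    # TRT-LLM kernels assume grouped-query attention with an even head split.
    if num_qo_heads % num_kv_heads != 0:
        return False
    # An FP8-quantized query can only be consumed by the TRT-LLM kernels.
    if q_dtype == FP8_DTYPE:
        return True
    # TRT-LLM prefill cannot pair an FP8 KV cache with an unquantized query.
    if is_prefill and kv_cache_dtype.startswith("fp8"):
        return False
    return True  # remaining heuristics (sinks, token counts) omitted here

For example, wants_trtllm_attention(64, 8, "fp8", torch.bfloat16, is_prefill=True, platform_ok=True) returns False, while the same call with q_dtype=FP8_DTYPE returns True, which is why the metadata builder above switches q_data_type to the KV-cache dtype when fusion is enabled.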
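
When the fusion does fire, the forward path above folds the four per-tensor scales into the two scalars the TRT-LLM kernels accept (bmm1_scale = q_scale * k_scale * sm_scale, bmm2_scale = v_scale / o_scale) and statically quantizes the query to FP8. Below is a minimal sketch of that arithmetic assuming plain PyTorch; the helper names fold_trtllm_scales and quantize_query_fp8 are made up and only illustrate the math, they are not the _custom_ops path used in the diff.

import torch

FP8_DTYPE = torch.float8_e4m3fn


def fold_trtllm_scales(q_scale: float, k_scale: float, v_scale: float,
                       o_scale: float, sm_scale: float) -> tuple[float, float]:
    # bmm1 consumes Q*K^T, so the query/key dequant scales and the softmax
    # scale fold into one scalar; bmm2 folds the value dequant scale and
    # divides by the output scale so the kernel can emit FP8 directly.
    bmm1_scale = q_scale * k_scale * sm_scale
    bmm2_scale = v_scale / o_scale
    return bmm1_scale, bmm2_scale


def quantize_query_fp8(query: torch.Tensor, q_scale: float) -> torch.Tensor:
    # Static per-tensor quantization of the query, flattening and restoring
    # the (num_tokens, num_heads, head_size) shape as the diff does.
    num_tokens, num_heads, head_size = query.shape
    flat = query.reshape(num_tokens, num_heads * head_size) / q_scale
    flat = flat.clamp(torch.finfo(FP8_DTYPE).min, torch.finfo(FP8_DTYPE).max)
    return flat.to(FP8_DTYPE).reshape(num_tokens, num_heads, head_size)

Because the kernel needs the output scale as a host scalar, the diff caches it once as layer._o_scale_float during the warmup run (with CUDA graphs disabled) and divides it into bmm2_scale, keeping the hot path free of device-to-host syncs.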
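
Finally, per the metadata-builder change above, the fused path only engages when the KV cache is FP8 and pass_config.enable_attn_fusion is set. A hypothetical way to switch that on from the offline API is sketched below; only kv_cache_dtype and the enable_attn_fusion field come from this diff, while the exact shape of the compilation_config argument and the model name are assumptions.

from vllm import LLM

# Hypothetical configuration sketch: enable the attn+quant fusion pass and an
# FP8 KV cache so the fused TRT-LLM FP8 path above can be taken. Field names
# other than kv_cache_dtype and enable_attn_fusion are assumptions.
llm = LLM(
    model="<fp8-quantized-model>",  # placeholder
    kv_cache_dtype="fp8",
    compilation_config={
        "pass_config": {
            "enable_attn_fusion": True,
        },
    },
)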