diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b7a2ca6ca9b2..e139c6b30586 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -664,7 +664,7 @@ steps: # Attention # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py + - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/test_cutlass_mla_decode.py # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' diff --git a/benchmarks/kernels/benchmark_trtllm_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py similarity index 99% rename from benchmarks/kernels/benchmark_trtllm_attention.py rename to benchmarks/kernels/benchmark_trtllm_decode_attention.py index 68c48858e61c..77136edca45b 100644 --- a/benchmarks/kernels/benchmark_trtllm_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -41,7 +41,6 @@ def benchmark_decode( device = "cuda" torch.manual_seed(0) - # Currently only HEAD_GRP_SIZE == 8 is supported HEAD_GRP_SIZE = 8 MAX_SEQ_LEN = max_seq_len diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py new file mode 100644 index 000000000000..67bd9aebbcca --- /dev/null +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import csv +import os +import random +from datetime import datetime + +import flashinfer +import torch + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 + +# KV Cache Layout for TRT-LLM +# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@torch.no_grad() +def benchmark_prefill( + num_seqs, + max_seq_len, + page_size=16, + dtype=torch.bfloat16, + kv_layout="HND", + num_kv_heads=8, + kv_cache_dtype="auto", + head_dim=128, + warmup=10, + trials=20, +): + torch.set_default_device("cuda") + torch.manual_seed(0) + + HEAD_GRP_SIZE = 8 + MAX_SEQ_LEN = max_seq_len + + # large number to reduce kv_cache reuse + NUM_BLOCKS = int(256000 / page_size) + + workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8) + + num_qo_heads = num_kv_heads * HEAD_GRP_SIZE + sm_scale = float(1.0 / (head_dim**0.5)) + + q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + q_lens[-1] = MAX_SEQ_LEN + max_q_len = max(q_lens) + q_indptr = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum( + torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32 + ), + ] + ) + q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype) + + kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)] + kv_lens[-1] = MAX_SEQ_LEN + + seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)] + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size + 
block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) + + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) + kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype) + k_scale = v_scale = 1.0 + + if kv_cache_dtype.startswith("fp8"): + kv_cache, _ = to_float8(kv_cache) + + output_trtllm = torch.empty(q.shape, dtype=dtype) + + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + page_size - 1) // page_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % page_size + if kv_last_page_len == 0: + kv_last_page_len = page_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + output_baseline = torch.empty(q.shape, dtype=dtype) + + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=True, + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=kv_cache.dtype, + ) + + def time_fn(fn, warmup=10, trials=20): + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + times = [] + for i in range(warmup): + fn() + for i in range(trials): + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + return sum(times) / len(times), torch.std(torch.tensor(times)) + + def baseline_prefill(): + return wrapper.run( + q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline + ) + + def trt_prefill(): + return flashinfer.prefill.trtllm_batch_context_with_kv_cache( + query=q, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens_tensor, + max_q_len=max_q_len, + max_kv_len=max_seq_len, + bmm1_scale=k_scale * sm_scale, + bmm2_scale=v_scale, + batch_size=num_seqs, + cum_seq_lens_q=q_indptr, + cum_seq_lens_kv=kv_indptr, + out=output_trtllm, + ) + + trt_mean, trt_std = time_fn(trt_prefill) + baseline_mean, baseline_std = time_fn(baseline_prefill) + + # Calculate percentage speedup (positive means TRT is faster) + speedup_percent = (baseline_mean - trt_mean) / baseline_mean + + print( + f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}" + f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}" + ) + + # Return results for CSV writing + return { + "num_seqs": num_seqs, + "trt_mean": trt_mean, + "trt_std": trt_std.item(), + "baseline_mean": baseline_mean, + "baseline_std": baseline_std.item(), + "speedup_percent": speedup_percent, + "q_dtype": str(dtype), + "kv_cache_dtype": kv_cache_dtype, + "page_size": page_size, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + "max_seq_len": max_seq_len, + } + + +def write_results_to_csv(results, filename=None): + """Write benchmark results to CSV file.""" + if filename is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" + + fieldnames = [ + "num_seqs", + "trt_mean", + "trt_std", + "baseline_mean", + "baseline_std", + "speedup_percent", + "q_dtype", + "kv_cache_dtype", + 
"page_size", + "num_kv_heads", + "head_dim", + "max_seq_len", + ] + + file_exists = os.path.exists(filename) + + with open(filename, "a", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + if not file_exists: + writer.writeheader() + + for result in results: + writer.writerow(result) + + print(f"Results written to {filename}") + + +if __name__ == "__main__": + num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] + all_results = [] + + print( + "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " + "output_dtype: bfloat16" + ) + print( + "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in num_seqs: + result = benchmark_prefill( + bs, + max_seq_len, + dtype=torch.bfloat16, + kv_cache_dtype="auto", + ) + all_results.append(result) + + # Write all results to CSV + write_results_to_csv(all_results) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py new file mode 100644 index 000000000000..e87ce520bc66 --- /dev/null +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import flashinfer +import pytest +import torch + +from vllm.platforms import current_platform + +if not current_platform.is_device_capability(100): + pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", + allow_module_level=True) + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 + +# KV Cache Layout for TRT-LLM +# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) + +MAX_Q_LEN = 1024 +MAX_KV_LEN = 4096 +BATCH_SIZES = [4, 12] +NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)] +HEAD_SIZES = [128] +BLOCK_SIZES = [16, 32] +KV_LAYOUTS = ["HND"] +DTYPES = [torch.float16, torch.bfloat16] +KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()] +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. 
+SOFT_CAPS = [None, 50.0] + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@torch.inference_mode +def test_flashinfer_trtllm_decode_with_baseline( + batch_size: int, + num_heads: tuple[int, int], + head_size: int, + block_size: int, + kv_layout: str, + dtype: torch.dtype, + kv_cache_dtype: Optional[torch.dtype], + soft_cap: Optional[float], +) -> None: + kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype + + torch.set_default_device("cuda") + current_platform.seed_everything(0) + + kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = MAX_KV_LEN + max_kv_len = torch.max(kv_lens).item() + num_seqs = len(kv_lens) + + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) + kv_scale = 1.0 + if kv_cache_dtype is current_platform.fp8_dtype(): + key_value_cache, kv_scale = to_float8(key_value_cache, + current_platform.fp8_dtype()) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + k_scale = v_scale = kv_scale + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=((num_query_heads // num_kv_heads) > 4)) + wrapper.plan(kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=scale, + q_data_type=dtype, + kv_data_type=kv_cache_dtype, + logits_soft_cap=soft_cap) + + output = torch.empty(query.shape, dtype=dtype) + wrapper.run(query, + key_value_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output) + + # TRTLLM Decode + 
kv_lens_tensor = kv_lens.to(torch.int32)
+    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+        query=query.contiguous(),
+        kv_cache=key_value_cache,
+        workspace_buffer=workspace_buffer,
+        block_tables=block_tables,
+        seq_lens=kv_lens_tensor,
+        max_seq_len=max_kv_len,
+        bmm1_scale=k_scale * scale,
+        bmm2_scale=v_scale,
+        out=output_trtllm,
+    )
+
+    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(output - output_trtllm))}"
+
+
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("soft_cap", [None])
+@torch.inference_mode
+def test_flashinfer_trtllm_prefill_with_baseline(
+    batch_size: int,
+    num_heads: tuple[int, int],
+    head_size: int,
+    block_size: int,
+    kv_layout: str,
+    dtype: torch.dtype,
+    kv_cache_dtype: Optional[torch.dtype],
+    soft_cap: Optional[float],
+) -> None:
+    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
+    if dtype != kv_cache_dtype:
+        pytest.skip(f"Not supported dtype({dtype}) with "
+                    f"kv_cache_dtype({kv_cache_dtype})")
+
+    torch.set_default_device("cuda")
+    current_platform.seed_everything(0)
+
+    q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32)
+    q_lens[-1] = MAX_Q_LEN
+    max_q_len = torch.max(q_lens).item()
+    q_indptr = torch.cat([
+        torch.tensor([0], dtype=torch.int32),
+        torch.cumsum(q_lens, dim=0, dtype=torch.int32),
+    ])
+
+    kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
+    kv_lens[-1] = MAX_KV_LEN
+
+    seq_lens = kv_lens + q_lens
+    max_seq_len = torch.max(seq_lens).item()
+    num_seqs = len(seq_lens)
+
+    num_query_heads = num_heads[0]
+    num_kv_heads = num_heads[1]
+    assert num_query_heads % num_kv_heads == 0
+
+    scale = head_size**-0.5
+
+    query = torch.randn(torch.sum(q_lens).item(),
+                        num_query_heads,
+                        head_size,
+                        dtype=dtype)
+
+    kv_cache_shape = None
+    if kv_layout == "NHD":
+        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
+    elif kv_layout == "HND":
+        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
+    else:
+        raise ValueError(f"Invalid kv_layout: {kv_layout}")
+    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
+    kv_scale = 1.0
+    if kv_cache_dtype is current_platform.fp8_dtype():
+        key_value_cache, kv_scale = to_float8(key_value_cache,
+                                              current_platform.fp8_dtype())
+
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+    block_tables = torch.randint(0,
+                                 NUM_BLOCKS,
+                                 (num_seqs, max_num_blocks_per_seq),
+                                 dtype=torch.int32)
+    k_scale = v_scale = kv_scale
+    kv_indptr = [0]
+    kv_indices = []
+    kv_last_page_lens = []
+    for i in range(num_seqs):
+        seq_len = seq_lens[i]
+        assert seq_len > 0
+        num_blocks = (seq_len + block_size - 1) // block_size
+        kv_indices.extend(block_tables[i, :num_blocks])
+        kv_indptr.append(kv_indptr[-1] + num_blocks)
+        kv_last_page_len = seq_len % block_size
+        if kv_last_page_len == 0:
+            kv_last_page_len = block_size
+        kv_last_page_lens.append(kv_last_page_len)
+
+    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
+    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
+    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer, kv_layout)
+    wrapper.plan(q_indptr,
+                 kv_indptr,
+                 kv_indices,
+                 kv_last_page_lens,
+                 num_query_heads,
+                 num_kv_heads,
+                 head_size,
+                 block_size,
+                 causal=True,
+                 sm_scale=scale,
+                 q_data_type=dtype,
+                 kv_data_type=kv_cache_dtype,
+                 logits_soft_cap=soft_cap)
+
+    output = torch.empty(query.shape, dtype=dtype)
+    wrapper.run(query,
+                key_value_cache,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                out=output)
+
+    # TRTLLM Prefill
+    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    flashinfer.prefill.trtllm_batch_context_with_kv_cache(
+        query=query.contiguous(),
+        kv_cache=key_value_cache,
+        workspace_buffer=workspace_buffer,
+        block_tables=block_tables,
+        seq_lens=seq_lens,
+        max_q_len=max_q_len,
+        max_kv_len=max_seq_len,
+        bmm1_scale=k_scale * scale,
+        bmm2_scale=v_scale,
+        batch_size=num_seqs,
+        cum_seq_lens_q=q_indptr,
+        cum_seq_lens_kv=kv_indptr,
+        out=output_trtllm,
+    )
+
+    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(output - output_trtllm))}"
diff --git a/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
deleted file mode 100644
index 2e2130fab6a2..000000000000
--- a/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
-
-import flashinfer
-import pytest
-import torch
-
-from vllm.platforms import current_platform
-
-if not current_platform.is_device_capability(100):
-    pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
-                allow_module_level=True)
-
-FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
-
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
-
-NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
-HEAD_SIZES = [128]
-BLOCK_SIZES = [16, 32]
-DTYPES = [torch.float16, torch.bfloat16]
-NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
-SOFT_CAPS = [None, 30.0, 50.0] - - -def to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) - scale = finfo.max / amax * 0.1 - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype), scale.float().reciprocal() - - -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", ["HND"]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", SOFT_CAPS) -@torch.inference_mode -def test_flashinfer_trtllm_decode_with_baseline( - kv_lens: list[int], - num_heads: tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, - soft_cap: Optional[float], - kv_layout: str, -) -> None: - torch.set_default_device("cuda") - current_platform.seed_everything(0) - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - kv_cache_shape = None - if kv_layout == "NHD": - kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) - elif kv_layout == "HND": - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) - else: - raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - k_scale = v_scale = 1.0 - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + block_size - 1) // block_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % block_size - if kv_last_page_len == 0: - kv_last_page_len = block_size - kv_last_page_lens.append(kv_last_page_len) - - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.\ - BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout, - use_tensor_cores=( - (num_query_heads//num_kv_heads) > 4) - ) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - q_data_type=dtype, - kv_data_type=dtype, - logits_soft_cap=soft_cap) - - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, key_value_cache, scale, out=output) - - # TRTLLM Decode - max_kv_len = max(kv_lens) - kv_lens_tensor = torch.tensor(kv_lens, - dtype=torch.int, - device=query.device) - output_trtllm = torch.empty(query.shape, dtype=dtype) - flashinfer.decode.trtllm_batch_decode_with_kv_cache( - query.contiguous(), - key_value_cache, - workspace_buffer, - block_tables, - kv_lens_tensor, - max_kv_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) - - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ - 
f"{torch.max(torch.abs(output - output_trtllm))}" diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b3372ce2eca8..78d8a67e37f8 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -46,7 +46,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) -from vllm.utils.flashinfer import use_trtllm_decode_attention +from vllm.utils.flashinfer import use_trtllm_attention logger = init_logger(__name__) @@ -1114,7 +1114,7 @@ class FlashInferImpl(AttentionImpl): assert decode_meta.decode_wrapper._sm_scale == softmax_scale # TODO: @pavanimajety Remove this once the switch happens # inside flashinfer. - if not use_trtllm_decode_attention( + if not use_trtllm_attention( num_decode_tokens, attn_metadata.max_decode_seq_len, kv_cache_dtype, attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim): diff --git a/vllm/envs.py b/vllm/envs.py index 78f955f78a98..9bce5c6d2e0b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1027,9 +1027,9 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_CUDNN_PREFILL": lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))), - # If set to 1, use the TRTLLM Decode Attention backend in flashinfer. - "VLLM_USE_TRTLLM_DECODE_ATTENTION": - lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None), + # If set to 1, use the TRTLLM Attention backend in flashinfer. + "VLLM_USE_TRTLLM_ATTENTION": + lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 29967bc51671..cce1aefaf9b0 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -124,7 +124,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool: @functools.cache def has_nvidia_artifactory() -> bool: """Return ``True`` if NVIDIA's artifactory is accessible. - + This checks connectivity to the kernel inference library artifactory which is required for downloading certain cubin kernels like TRTLLM FHMA. """ @@ -144,7 +144,7 @@ def has_nvidia_artifactory() -> bool: return False -def use_trtllm_decode_attention( +def use_trtllm_attention( num_tokens: int, max_seq_len: int, kv_cache_dtype: str, @@ -159,29 +159,26 @@ def use_trtllm_decode_attention( # Check if the dimensions are supported by TRTLLM decode attention if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None - or num_qo_heads // num_kv_heads > 8 or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128): return False - env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION + env_value = envs.VLLM_USE_TRTLLM_ATTENTION if env_value is not None: - logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s", - env_value) + logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) # Environment variable is set - respect it # Making the conditional check for zero because # the path is automatically enabled if the batch size condition # is satisfied. 
no_use_trtllm = (env_value == "0") if not no_use_trtllm: - logger.info_once("Using TRTLLM decode attention.") + logger.info_once("Using TRTLLM attention.") return not no_use_trtllm else: # Environment variable not set - use auto-detection use_trtllm = (num_tokens <= 256 and max_seq_len < 131072 and kv_cache_dtype == "auto") if use_trtllm: - logger.warning_once( - "Using TRTLLM decode attention (auto-detected).") + logger.warning_once("Using TRTLLM attention (auto-detected).") return use_trtllm @@ -195,5 +192,5 @@ __all__ = [ "has_flashinfer_moe", "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", - "use_trtllm_decode_attention", + "use_trtllm_attention", ] diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 3697cb9387a9..8592d1b26dfa 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) from flashinfer.decode import (_get_range_buf, get_seq_lens, trtllm_batch_decode_with_kv_cache) +from flashinfer.prefill import trtllm_batch_context_with_kv_cache import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import use_trtllm_decode_attention +from vllm.utils.flashinfer import use_trtllm_attention from vllm.v1.attention.backends.flash_attn import use_cascade_attention # yapf conflicts with isort for this block # yapf: disable @@ -149,9 +150,12 @@ class FlashInferMetadata: slot_mapping: torch.Tensor # For flashinfer trtllm batch decode + max_q_len: int max_seq_len: int seq_lens: torch.Tensor block_table_tensor: torch.Tensor + prefill_use_trtllm: bool + decode_use_trtllm: bool # For handling prefill decode split num_decodes: int @@ -170,6 +174,9 @@ class FlashInferMetadata: decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + qo_indptr_gpu: Optional[torch.Tensor] = None + paged_kv_indptr_gpu: Optional[torch.Tensor] = None + def __post_init__(self): if self.head_dim is not None: FlashInferBackend.validate_head_size(self.head_dim) @@ -305,8 +312,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): 2, self._get_workspace_buffer(), get_kv_cache_layout()) return self._cascade_wrapper - def _plan(self, num_prefills: int, num_decodes: int, - attn_metadata: FlashInferMetadata): + def _plan(self, attn_metadata: FlashInferMetadata): if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( @@ -341,6 +347,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # Regular attention (common case). # Decodes are at the front and prefills are at the back, # according to reorder_batch() + num_prefills = attn_metadata.num_prefills + num_decodes = attn_metadata.num_decodes if num_prefills > 0: # Decodes are first so prefills start after the last decode prefill_start = num_decodes @@ -356,23 +364,31 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # to be relative to the start of the prefill queries. 
qo_indptr_cpu = attn_metadata.qo_indptr_cpu[ prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start] - attn_metadata.prefill_wrapper.plan( - qo_indptr_cpu, - attn_metadata.paged_kv_indptr_cpu[prefill_start:], - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len_cpu[prefill_start:], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. - logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.kv_data_type, - ) + paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[ + prefill_start:] + if not attn_metadata.prefill_use_trtllm: + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu, + paged_kv_indptr_cpu, + attn_metadata.paged_kv_indices, + attn_metadata. + paged_kv_last_page_len_cpu[prefill_start:], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters. + logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.kv_data_type, + ) + else: + attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device) + attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( + self.device) if num_decodes > 0: pure_decode = num_prefills == 0 @@ -400,11 +416,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): attn_metadata.decode_wrapper = self._get_decode_wrapper( num_input_tokens, use_cudagraph) - if not use_trtllm_decode_attention( - num_decodes, attn_metadata.max_seq_len, - self.cache_config.cache_dtype, - attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, - attn_metadata.head_dim): + if not attn_metadata.decode_use_trtllm: # Use the persistent buffer with padding length, # instead of the same address but chunked version # in atten_metadata when using cudagraph. 
@@ -437,6 +449,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): split_decodes_and_prefills(common_attn_metadata) page_size = self.kv_cache_spec.block_size + max_q_len = common_attn_metadata.max_query_len max_seq_len = common_attn_metadata.seq_lens_cpu.max() seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu @@ -503,6 +516,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): cache_dtype) else: kv_cache_dtype = self.kv_cache_spec.dtype + + num_qo_heads = self.vllm_config.model_config.get_num_attention_heads( + self.vllm_config.parallel_config) + num_kv_heads = self.kv_cache_spec.num_kv_heads + head_dim = self.kv_cache_spec.head_size + + # currently prefill trtllm attention does not support fp8 kv cache + # trtllm may not support sliding window + prefill_use_trtllm = (self.global_hyperparameters.window_left == -1 + and not cache_dtype.startswith("fp8") + and use_trtllm_attention( + num_prefill_tokens, max_seq_len, cache_dtype, + num_qo_heads, num_kv_heads, head_dim)) + decode_use_trtllm = (self.global_hyperparameters.window_left == -1 + and use_trtllm_attention( + num_decode_tokens, max_seq_len, cache_dtype, + num_qo_heads, num_kv_heads, head_dim)) + attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu, @@ -510,14 +541,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): paged_kv_indices=paged_kv_indices, paged_kv_last_page_len_cpu=self. paged_kv_last_page_len_cpu[:num_reqs], - num_qo_heads=self.vllm_config.model_config.get_num_attention_heads( - self.vllm_config.parallel_config), - num_kv_heads=self.kv_cache_spec.num_kv_heads, - head_dim=self.kv_cache_spec.head_size, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, page_size=page_size, kv_data_type=kv_cache_dtype, q_data_type=self.vllm_config.model_config.dtype, slot_mapping=common_attn_metadata.slot_mapping, + max_q_len=max_q_len, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table_tensor=block_table_tensor, + prefill_use_trtllm=prefill_use_trtllm, + decode_use_trtllm=decode_use_trtllm, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, num_prefills=num_prefills, @@ -527,12 +563,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu, shared_kv_page_indices_cpu=shared_kv_page_indices_cpu, shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table_tensor=block_table_tensor, ) - self._plan(num_prefills, num_decodes, attn_metadata) + self._plan(attn_metadata) return attn_metadata @@ -698,30 +731,64 @@ class FlashInferImpl(AttentionImpl): # Regular attention (common case). 
# Decodes are at the front and prefills are at the back, # according to reorder_batch() - if prefill_wrapper := attn_metadata.prefill_wrapper: + if num_prefill_tokens > 0: + prefill_wrapper = attn_metadata.prefill_wrapper prefill_query = query[num_decode_tokens:] assert prefill_query.shape[0] == num_prefill_tokens assert prefill_wrapper is not None - assert prefill_wrapper._causal - assert prefill_wrapper._window_left == window_left - assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap - or 0.0) - assert prefill_wrapper._sm_scale == self.scale - prefill_wrapper.run( - prefill_query, - kv_cache_permute, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[num_decode_tokens:], - ) - if decode_wrapper := attn_metadata.decode_wrapper: + + if not attn_metadata.prefill_use_trtllm: + assert prefill_wrapper._causal + assert prefill_wrapper._window_left == window_left + assert prefill_wrapper._logits_soft_cap == ( + self.logits_soft_cap or 0.0) + assert prefill_wrapper._sm_scale == self.scale + prefill_wrapper.run( + prefill_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[num_decode_tokens:], + ) + else: + # prefill_query may be non-contiguous + prefill_query = prefill_query.contiguous() + workspace_buffer = prefill_wrapper._float_workspace_buffer + block_tables_prefill = attn_metadata.block_table_tensor[ + num_decode_tokens:] + seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:] + + # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND + assert get_kv_cache_layout() == "HND" + assert prefill_query.is_contiguous() + assert kv_cache_permute.is_contiguous() + assert workspace_buffer.is_contiguous() + assert block_tables_prefill.is_contiguous() + assert seq_lens_prefill.is_contiguous() + + trtllm_batch_context_with_kv_cache( + query=prefill_query, + kv_cache=kv_cache_permute, + workspace_buffer=workspace_buffer, + block_tables=block_tables_prefill, + seq_lens=seq_lens_prefill, + max_q_len=attn_metadata.max_q_len, + max_kv_len=attn_metadata.max_seq_len, + bmm1_scale=layer._k_scale_float * self.scale, + bmm2_scale=layer._v_scale_float, + batch_size=attn_metadata.num_prefills, + cum_seq_lens_q=attn_metadata.qo_indptr_gpu, + cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, + out=output[num_decode_tokens:], + ) + + if num_decode_tokens > 0: + decode_wrapper = attn_metadata.decode_wrapper decode_query = query[:num_decode_tokens] assert decode_query.shape[0] == num_decode_tokens assert decode_wrapper is not None - if not use_trtllm_decode_attention( - attn_metadata.num_decodes, attn_metadata.max_seq_len, - self.kv_cache_dtype, attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, attn_metadata.head_dim): + + if not attn_metadata.decode_use_trtllm: assert decode_wrapper._window_left == window_left assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) @@ -734,34 +801,32 @@ class FlashInferImpl(AttentionImpl): out=output[:num_decode_tokens], ) else: + # decode_query may be non-contiguous + decode_query = decode_query.contiguous() + workspace_buffer = decode_wrapper._float_workspace_buffer + block_tables_decode = attn_metadata.block_table_tensor[: + num_decode_tokens] + seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] + # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND - if num_decode_tokens > 0: - # decode_query may be non-contiguous - decode_query = decode_query.contiguous() - block_tables_decode = attn_metadata.block_table_tensor[: - 
num_decode_tokens] - seq_lens_decode = attn_metadata.seq_lens[: - num_decode_tokens] - workspace_buffer = decode_wrapper._float_workspace_buffer + assert get_kv_cache_layout() == "HND" + assert decode_query.is_contiguous() + assert kv_cache_permute.is_contiguous() + assert workspace_buffer.is_contiguous() + assert block_tables_decode.is_contiguous() + assert seq_lens_decode.is_contiguous() - assert get_kv_cache_layout() == "HND" - assert decode_query.is_contiguous() - assert kv_cache_permute.is_contiguous() - assert block_tables_decode.is_contiguous() - assert seq_lens_decode.is_contiguous() - assert workspace_buffer.is_contiguous() - - trtllm_batch_decode_with_kv_cache( - query=decode_query, - kv_cache=kv_cache_permute, - workspace_buffer=workspace_buffer, - block_tables=block_tables_decode, - seq_lens=seq_lens_decode, - max_seq_len=attn_metadata.max_seq_len, - bmm1_scale=layer._k_scale_float * self.scale, - bmm2_scale=layer._v_scale_float, - out=output[:num_decode_tokens], - ) + trtllm_batch_decode_with_kv_cache( + query=decode_query, + kv_cache=kv_cache_permute, + workspace_buffer=workspace_buffer, + block_tables=block_tables_decode, + seq_lens=seq_lens_decode, + max_seq_len=attn_metadata.max_seq_len, + bmm1_scale=layer._k_scale_float * self.scale, + bmm2_scale=layer._v_scale_float, + out=output[:num_decode_tokens], + ) return output_padded @@ -786,8 +851,8 @@ def fast_plan_decode( non_blocking: bool = True, ) -> None: """ - A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for - cudagraph capture/replay, while the no cudagraph version turns back + A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for + cudagraph capture/replay, while the no cudagraph version turns back to the original plan. using original plan after passing host-side buffers: - only host-to-device copy of indptr and last_page_len buffers