From 7bd4c37ae7c6f2223c1a031bbdd2e3435d53da94 Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Fri, 11 Jul 2025 02:23:23 -0700 Subject: [PATCH] [Core] Add Flashinfer TRTLLM Backend for Flashinfer decode path (SM100). (#19825) Signed-off-by: Pavani Majety Signed-off-by: mgoin Co-authored-by: shuw Co-authored-by: mgoin --- .../kernels/benchmark_trtllm_attention.py | 240 ++++++++++++++++++ ...test_flashinfer_trtllm_decode_attention.py | 140 ++++++++++ vllm/attention/backends/flashinfer.py | 123 +++++++-- vllm/engine/arg_utils.py | 2 + vllm/envs.py | 6 +- vllm/platforms/cuda.py | 19 +- vllm/v1/attention/backends/flashinfer.py | 183 ++++++++++--- vllm/v1/attention/backends/utils.py | 10 +- 8 files changed, 667 insertions(+), 56 deletions(-) create mode 100644 benchmarks/kernels/benchmark_trtllm_attention.py create mode 100644 tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py diff --git a/benchmarks/kernels/benchmark_trtllm_attention.py b/benchmarks/kernels/benchmark_trtllm_attention.py new file mode 100644 index 0000000000000..8c980f930366c --- /dev/null +++ b/benchmarks/kernels/benchmark_trtllm_attention.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import csv +import os +import random +from datetime import datetime + +import flashinfer +import torch + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 + +# KV Cache Layout for TRT-LLM +# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@torch.no_grad() +def benchmark_decode( + num_seqs, + max_seq_len, + page_size=16, + dtype=torch.bfloat16, + kv_layout="HND", + num_kv_heads=8, + kv_cache_dtype="auto", + head_dim=128, + warmup=10, + trials=20, +): + torch.set_default_device("cuda") + device = "cuda" + torch.manual_seed(0) + + # Currently only HEAD_GRP_SIZE == 8 is supported + HEAD_GRP_SIZE = 8 + MAX_SEQ_LEN = max_seq_len + + # large number to reduce kv_cache reuse + NUM_BLOCKS = int(256000 / page_size) + + workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) + + # For decode, batch_size is num_decode_token + num_qo_heads = num_kv_heads * HEAD_GRP_SIZE + sm_scale = float(1.0 / (head_dim**0.5)) + q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) + kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + + max_kv_len = max(kv_lens) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) + max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size + + block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) + + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) + kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) + k_scale = v_scale = 1.0 + + if kv_cache_dtype.startswith("fp8"): + kv_cache, _ = to_float8(kv_cache) + + # Benchmark TRT decode + def trt_decode(): + return flashinfer.decode.trtllm_batch_decode_with_kv_cache( + q, + kv_cache, + workspace_buffer, + num_qo_heads, + num_kv_heads, + sm_scale, + block_tables, + kv_lens_tensor, + page_size, + max_kv_len, + kv_cache_dtype, + k_scale, + v_scale, + ) + + def 
time_fn(fn, warmup=10, trials=20): + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + times = [] + for i in range(warmup): + fn() + for i in range(trials): + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + return sum(times) / len(times), torch.std(torch.tensor(times)) + + # TRT Decode + trt_mean, trt_std = time_fn(trt_decode) + + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + page_size - 1) // page_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % page_size + if kv_last_page_len == 0: + kv_last_page_len = page_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), + ) + + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + "NONE", + q_data_type=dtype, + kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype, + ) + + def baseline_decode(): + return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale) + + baseline_mean, baseline_std = time_fn(baseline_decode) + + # Calculate percentage speedup (positive means TRT is faster) + speedup_percent = (baseline_mean - trt_mean) / baseline_mean + + print( + f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" + f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" + ) + + # Return results for CSV writing + return { + "num_seqs": num_seqs, + "trt_mean": trt_mean, + "trt_std": trt_std.item(), + "baseline_mean": baseline_mean, + "baseline_std": baseline_std.item(), + "speedup_percent": speedup_percent, + "q_dtype": str(dtype), + "kv_cache_dtype": kv_cache_dtype, + "page_size": page_size, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + "max_seq_len": max_seq_len, + } + + +def write_results_to_csv(results, filename=None): + """Write benchmark results to CSV file.""" + if filename is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" + + fieldnames = [ + "num_seqs", + "trt_mean", + "trt_std", + "baseline_mean", + "baseline_std", + "speedup_percent", + "q_dtype", + "kv_cache_dtype", + "page_size", + "num_kv_heads", + "head_dim", + "max_seq_len", + ] + + file_exists = os.path.exists(filename) + + with open(filename, "a", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + if not file_exists: + writer.writeheader() + + for result in results: + writer.writerow(result) + + print(f"Results written to {filename}") + + +if __name__ == "__main__": + num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] + all_results = [] + + print("Running benchmark for kv_cache_dtype: bfloat16") + print( + "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in num_seqs: + result = benchmark_decode( + bs, max_seq_len, dtype=torch.bfloat16, 
kv_cache_dtype="auto" + ) + all_results.append(result) + + print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8") + print( + "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in num_seqs: + result = benchmark_decode( + bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8" + ) + all_results.append(result) + + # Write all results to CSV + write_results_to_csv(all_results) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py new file mode 100644 index 0000000000000..96eee13695a9d --- /dev/null +++ b/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import flashinfer +import pytest +import torch + +from vllm.platforms import current_platform + +if not current_platform.is_device_capability(100): + pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", + allow_module_level=True) + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 + +# KV Cache Layout for TRT-LLM +# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) + +NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)] +HEAD_SIZES = [128] +BLOCK_SIZES = [16, 32] +DTYPES = [torch.float16, torch.bfloat16] +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. +SOFT_CAPS = [None, 30.0, 50.0] + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("kv_layout", ["HND"]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@torch.inference_mode +def test_flashinfer_trtllm_decode_with_baseline( + kv_lens: list[int], + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], + kv_layout: str, +) -> None: + torch.set_default_device("cuda") + current_platform.seed_everything(0) + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + k_scale = v_scale = 1.0 + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + 
num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.\ + BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout, + use_tensor_cores=( + (num_query_heads//num_kv_heads) > 4) + ) + wrapper.plan(kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + q_data_type=dtype, + kv_data_type=dtype, + logits_soft_cap=soft_cap) + + output = wrapper.run(query, key_value_cache, scale) + + # TRTLLM Decode + max_kv_len = max(kv_lens) + kv_lens_tensor = torch.tensor(kv_lens, + dtype=torch.int, + device=query.device) + output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query.contiguous(), + key_value_cache, + workspace_buffer, + num_query_heads, + num_kv_heads, + scale, + block_tables, + kv_lens_tensor, + block_size, + max_kv_len, + "auto", + k_scale, + v_scale, + ) + + torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - output_trtllm))}" diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b7d80f5194c0f..5bbe340b14300 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -11,7 +11,8 @@ from vllm.multimodal import MultiModalPlaceholderMap try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper - from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import (CUDAGraphBatchDecodeWithPagedKVCacheWrapper, + trtllm_batch_decode_with_kv_cache) from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -22,7 +23,10 @@ except ImportError: BatchDecodeWithPagedKVCacheWrapper = None CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None + trtllm_batch_decode_with_kv_cache = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + raise ImportError("FlashInfer is not installed. 
Please install it from " + "https://github.com/flashinfer-ai/flashinfer") from None import torch @@ -40,6 +44,7 @@ from vllm.attention.layer import Attention from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) @@ -49,10 +54,9 @@ if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) -FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD" - class FlashInferBackend(AttentionBackend): + cached_sm100a_supported: Optional[bool] = None @staticmethod def get_name() -> str: @@ -85,7 +89,7 @@ class FlashInferBackend(AttentionBackend): @staticmethod def get_kv_cache_stride_order() -> Tuple[int, ...]: - cache_layout = FLASHINFER_KV_CACHE_LAYOUT + cache_layout = FlashInferState.get_kv_cache_layout() assert (cache_layout in ("NHD", "HND")) stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) @@ -119,6 +123,47 @@ class FlashInferBackend(AttentionBackend): else: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @staticmethod + def use_trtllm_decode_attention( + batch_size: int, + max_seq_len: int, + kv_cache_dtype: str, + num_qo_heads: Optional[int], + num_kv_heads: Optional[int], + attn_head_size: Optional[int], + ) -> bool: + if FlashInferBackend.cached_sm100a_supported is None: + FlashInferBackend.cached_sm100a_supported = ( + current_platform.has_device_capability(100)) + if not FlashInferBackend.cached_sm100a_supported: + return False + # Check if the dimensions are supported by TRTLLM decode attention + if (attn_head_size is None or num_qo_heads is None + or num_kv_heads is None or num_qo_heads // num_kv_heads > 8 + or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128): + return False + env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION + if env_value is not None: + logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s", + env_value) + # Environment variable is set - respect it + # Making the conditional check for zero because + # the path is automatically enabled if the batch size condition + # is satisfied. 
+ no_use_trtllm = (env_value == "0") + if not no_use_trtllm: + logger.info_once("Using TRTLLM decode attention.") + return not no_use_trtllm + else: + # Environment variable not set - use auto-detection + use_trtllm = (FlashInferBackend.cached_sm100a_supported + and batch_size <= 256 and max_seq_len < 131072 + and kv_cache_dtype == "auto") + if use_trtllm: + logger.warning_once( + "Using TRTLLM decode attention (auto-detected).") + return use_trtllm + @dataclass class PerLayerParameters: @@ -207,10 +252,19 @@ class FlashInferState(AttentionState): device=self.runner.device) return self._workspace_buffer - def get_kv_cache_layout(self): - if self._kv_cache_layout is None: - self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT - return self._kv_cache_layout + @staticmethod + def get_kv_cache_layout(): + from vllm.v1.attention.backends.utils import _KV_CACHE_LAYOUT_OVERRIDE + if _KV_CACHE_LAYOUT_OVERRIDE is not None: + logger.info_once("Using KV cache layout %s", + _KV_CACHE_LAYOUT_OVERRIDE) + return _KV_CACHE_LAYOUT_OVERRIDE + cache_layout = envs.VLLM_KV_CACHE_LAYOUT + if cache_layout is None: + logger.info_once("Using default KV cache layout NHD") + return "NHD" + logger.info_once("Using KV cache layout %s", cache_layout) + return cache_layout def _get_prefill_wrapper(self): if self._prefill_wrapper is None: @@ -323,6 +377,8 @@ class FlashInferState(AttentionState): num_prefill_tokens=0, num_decode_tokens=batch_size, max_prefill_seq_len=0, + max_decode_seq_len=0, + seq_lens_tensor=self._graph_seq_lens, block_tables=self._graph_block_tables, paged_kv_indptr=paged_kv_indptr_tensor_host, paged_kv_indices=paged_kv_indices_tensor_host, @@ -348,6 +404,8 @@ class FlashInferState(AttentionState): attn_metadata, is_encoder_decoder_model: bool = False): return { + "block_tables": attn_metadata.block_tables, + "seq_lens_tensor": attn_metadata.seq_lens_tensor, "slot_mapping": attn_metadata.slot_mapping, } @@ -355,7 +413,13 @@ class FlashInferState(AttentionState): input_buffers, attn_metadata, is_encoder_decoder_model: bool = False): - return + # FlashInfer-specific logic: copy additional tensors + num_total_blocks = attn_metadata.decode_metadata.seq_lens_tensor.shape[ + 0] + input_buffers["seq_lens_tensor"][:num_total_blocks].copy_( + attn_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"][:num_total_blocks].copy_( + attn_metadata.block_tables, non_blocking=True) def begin_forward(self, model_input): assert not self._is_graph_capturing @@ -388,6 +452,8 @@ class FlashInferMetadata(AttentionMetadata): # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int + max_decode_seq_len: int + # Number of query tokens for each request in the batch. # Currently, we require that all requests have the same number of query # tokens during the decoding phase. 
When speculavie decoding is enabled, @@ -790,6 +856,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): use_captured_graph = cuda_graph_pad_size != -1 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens decode_query_len = max(query_lens[self.num_prefills:], default=1) @@ -895,6 +962,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, block_tables=block_tables, paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, @@ -1081,13 +1149,36 @@ class FlashInferImpl(AttentionImpl): assert decode_meta.decode_wrapper._logits_soft_cap == ( logits_soft_cap or 0.0) assert decode_meta.decode_wrapper._sm_scale == softmax_scale - - decode_output = decode_meta.decode_wrapper.run( - decode_query, - kv_cache.permute(*stride_order), - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - ) + # TODO: @pavanimajety Remove this once the switch happens + # inside flashinfer. + if not FlashInferBackend.use_trtllm_decode_attention( + num_decode_tokens, attn_metadata.max_decode_seq_len, + kv_cache_dtype, attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, attn_metadata.head_dim): + decode_output = decode_meta.decode_wrapper.run( + decode_query, + kv_cache.permute(*stride_order), + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + ) + else: + workspace_buffer = ( + decode_meta.decode_wrapper._int_workspace_buffer) + assert FlashInferState.get_kv_cache_layout() == "HND" + decode_output = trtllm_batch_decode_with_kv_cache( + query=decode_query, + kv_cache=kv_cache.permute(*stride_order), + workspace_buffer=workspace_buffer, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + scale=softmax_scale, + block_tables=attn_metadata.block_tables, + seq_lens=decode_meta.seq_lens_tensor, + block_size=attn_metadata.page_size, + max_seq_len=attn_metadata.max_decode_seq_len, + kv_cache_dtype=kv_cache_dtype, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float) if prefill_output is None and decode_output is not None: # Decode only batch. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1b8dc640e056c..f47499309d8f6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1424,6 +1424,8 @@ class EngineArgs: from vllm.attention.utils.fa_utils import ( flash_attn_supports_fp8) supported = flash_attn_supports_fp8() + elif envs.VLLM_USE_TRTLLM_DECODE_ATTENTION: + supported = True if not supported: _raise_or_fallback(feature_name="--kv-cache-dtype", recommend_to_remove=False) diff --git a/vllm/envs.py b/vllm/envs.py index bf5dce2ca5c4f..7bff6ade81512 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -959,7 +959,11 @@ environment_variables: dict[str, Callable[[], Any]] = { # consumer. This is only applicable when using NixlConnector in a # disaggregated decode-prefill setup. "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": - lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")) + lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")), + + # If set to 1, use the TRTLLM Decode Attention backend in flashinfer. 
+ "VLLM_USE_TRTLLM_DECODE_ATTENTION": + lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None), } # --8<-- [end:env-vars-definition] diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 35a2b48c7d016..00151296a7544 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -244,6 +244,10 @@ class CudaPlatformBase(Platform): if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") + if cls.has_device_capability(100): + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + set_kv_cache_layout("HND") return FLASHINFER_V1 elif selected_backend == _Backend.FLEX_ATTENTION: logger.info_once("Using FlexAttention backend on V1 engine.") @@ -271,9 +275,13 @@ class CudaPlatformBase(Platform): supports_head_size(FLASHINFER_V1, head_size): try: import flashinfer # noqa: F401 + + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) logger.info_once( - "Using FlashInfer backend on V1 engine by default for " - "Blackwell (SM 10.0) GPUs.") + "Using FlashInfer backend with HND KV cache layout on " + "V1 engine by default for Blackwell (SM 10.0) GPUs.") + set_kv_cache_layout("HND") return FLASHINFER_V1 except ImportError: logger.info_once( @@ -293,6 +301,13 @@ class CudaPlatformBase(Platform): # Backends for V0 engine if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") + if cls.has_device_capability(100): + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + logger.info_once( + "Using HND KV cache layout on V1 engine by default for " + "Blackwell (SM 10.0) GPUs.") + set_kv_cache_layout("HND") return "vllm.attention.backends.flashinfer.FlashInferBackend" elif selected_backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4cca618f6b3c9..4ae595c976b3e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -10,11 +10,13 @@ import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) +from flashinfer.decode import trtllm_batch_decode_with_kv_cache import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata, @@ -38,6 +40,7 @@ logger = init_logger(__name__) class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True + cached_sm100a_supported: Optional[bool] = None @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -93,6 +96,57 @@ class FlashInferBackend(AttentionBackend): raise ValueError(f"Unknown cache layout format {cache_layout}.") return stride_order + @staticmethod + def use_trtllm_decode_attention( + batch_size: int, + max_seq_len: int, + kv_cache_dtype: str, + num_qo_heads: int, + num_kv_heads: int, + attn_head_size: int, + ) -> bool: + if FlashInferBackend.cached_sm100a_supported is None: + FlashInferBackend.cached_sm100a_supported = ( + current_platform.has_device_capability(100)) + if not FlashInferBackend.cached_sm100a_supported: + return False + if (num_qo_heads // num_kv_heads > 8 + or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128): + return False + 
env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION + if env_value is not None: + logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s", + env_value) + # Environment variable is set - respect it + # Making the conditional check for zero because + # the path is automatically enabled if the batch size condition + # is satisfied. + no_use_trtllm = env_value == "0" + if not no_use_trtllm: + logger.info_once( + "VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, " + "using TRTLLM decode attention.") + return not no_use_trtllm + else: + # Environment variable not set - use auto-detection + # Only supports attention head size of 128 + use_trtllm = (FlashInferBackend.cached_sm100a_supported + and batch_size <= 256 and max_seq_len < 131072 + and kv_cache_dtype == "auto") + if use_trtllm: + logger.warning_once( + "Using TRTLLM decode attention (auto-detected).") + return use_trtllm + + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @dataclass class FlashInferMetadata: @@ -127,12 +181,18 @@ class FlashInferMetadata: # Block size of vllm page_size: int # The data type of the paged kv cache - data_type: torch.dtype + kv_data_type: torch.dtype # The data type of the query q_data_type: torch.dtype slot_mapping: torch.Tensor + # For flashinfer trtllm batch decode + max_seq_len: int + seq_lens: torch.Tensor + block_table_tensor: torch.Tensor + workspace_buffer: torch.Tensor + # For handling prefill decode split num_decodes: int num_decode_tokens: int @@ -299,6 +359,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): window_left=self.global_hyperparameters.window_left, logits_soft_cap=self.global_hyperparameters.logits_soft_cap, q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.kv_data_type, ) else: # Regular attention (common case). @@ -334,28 +395,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): logits_soft_cap=self.global_hyperparameters. logits_soft_cap, q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.data_type, + kv_data_type=attn_metadata.kv_data_type, ) if self._num_decodes > 0: attn_metadata.decode_wrapper = self._get_decode_wrapper() - attn_metadata.decode_wrapper.plan( - attn_metadata.paged_kv_indptr[:self._num_decodes + 1], - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len[:self._num_decodes], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. - pos_encoding_mode="NONE", - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. - logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.data_type, - ) + if not FlashInferBackend.use_trtllm_decode_attention( + self._num_decodes, attn_metadata.max_seq_len, + attn_metadata.kv_data_type, attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, attn_metadata.head_dim): + attn_metadata.decode_wrapper.plan( + attn_metadata.paged_kv_indptr[:self._num_decodes + 1], + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len[:self. 
+ _num_decodes], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters. + logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.kv_data_type, + ) def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata): @@ -368,6 +434,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): page_size = self.kv_cache_spec.block_size device = self.runner.device qo_indptr = common_attn_metadata.query_start_loc + max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) seq_lens = common_attn_metadata.seq_lens block_table_tensor = self.block_table.get_device_tensor()[:num_reqs] slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( @@ -416,7 +483,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): paged_kv_last_page_len = seq_lens % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) - + cache_dtype = self.runner.cache_config.cache_dtype + if cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + cache_dtype) + else: + kv_cache_dtype = self.kv_cache_spec.dtype attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, qo_indptr=qo_indptr, @@ -427,7 +499,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_kv_heads=self.kv_cache_spec.num_kv_heads, head_dim=self.kv_cache_spec.head_size, page_size=page_size, - data_type=self.kv_cache_spec.dtype, + kv_data_type=kv_cache_dtype, q_data_type=self.runner.dtype, slot_mapping=slot_mapping, num_decodes=self._num_decodes, @@ -439,6 +511,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): shared_kv_page_indptr=shared_kv_page_indptr, shared_kv_page_indices=shared_kv_page_indices, shared_kv_last_page_len=shared_kv_last_page_len, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table_tensor=block_table_tensor, + workspace_buffer=self._workspace_buffer, ) self._plan(attn_metadata) @@ -514,7 +590,11 @@ class FlashInferImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size] + kv_cache: shape - + # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + # HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + + attn_metadata: Metadata for attention. 
Returns: shape = [num_tokens, num_heads * head_size] @@ -560,6 +640,13 @@ class FlashInferImpl(AttentionImpl): layer._v_scale, ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if self.kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + window_left = (self.sliding_window[0] if self.sliding_window is not None else -1) @@ -597,21 +684,45 @@ class FlashInferImpl(AttentionImpl): v_scale=layer._v_scale_float, out=output[num_decode_tokens:], ) - if decode_wrapper := attn_metadata.decode_wrapper: decode_query = query[:num_decode_tokens] assert decode_query.shape[0] == num_decode_tokens - assert decode_wrapper is not None - assert decode_wrapper._window_left == window_left - assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap - or 0.0) - assert decode_wrapper._sm_scale == self.scale - decode_wrapper.run( - decode_query, - kv_cache.permute(*stride_order), - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[:num_decode_tokens], - ) - + if not FlashInferBackend.use_trtllm_decode_attention( + attn_metadata.num_decodes, attn_metadata.max_seq_len, + self.kv_cache_dtype, attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, attn_metadata.head_dim): + assert decode_wrapper is not None + assert decode_wrapper._window_left == window_left + assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap + or 0.0) + assert decode_wrapper._sm_scale == self.scale + decode_wrapper.run( + decode_query, + kv_cache.permute(*stride_order), + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[:num_decode_tokens], + ) + else: + # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND + if num_decode_tokens > 0: + assert get_kv_cache_layout() == "HND" + output[:num_decode_tokens] = ( + trtllm_batch_decode_with_kv_cache( + query=decode_query, + kv_cache=kv_cache.permute(*stride_order), + workspace_buffer=attn_metadata.workspace_buffer, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + scale=self.scale, + block_tables=attn_metadata. + block_table_tensor[:num_decode_tokens], + seq_lens=attn_metadata. + seq_lens[:num_decode_tokens], + block_size=attn_metadata.page_size, + max_seq_len=attn_metadata.max_seq_len, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + )) return output_padded diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 3787b39a81be5..88adc32406e4a 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.logger import init_logger logger = init_logger(__name__) +_KV_CACHE_LAYOUT_OVERRIDE = None @dataclass @@ -103,6 +104,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): @functools.lru_cache def get_kv_cache_layout(): + global _KV_CACHE_LAYOUT_OVERRIDE # Override with format specified by the user. cache_layout = envs.VLLM_KV_CACHE_LAYOUT if cache_layout is None: @@ -110,10 +112,16 @@ def get_kv_cache_layout(): else: logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ "detected. 
Setting KV cache layout to %s.", cache_layout) - + if _KV_CACHE_LAYOUT_OVERRIDE is not None: + cache_layout = _KV_CACHE_LAYOUT_OVERRIDE return cache_layout +def set_kv_cache_layout(cache_layout: str): + global _KV_CACHE_LAYOUT_OVERRIDE + _KV_CACHE_LAYOUT_OVERRIDE = cache_layout + + @dataclass class PerLayerParameters: """
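
Usage sketch for the new environment variable: when VLLM_USE_TRTLLM_DECODE_ATTENTION is unset the backend auto-detects on SM100, "0" forces the regular FlashInfer decode wrapper, and any other value forces the TRTLLM kernel (and, per the arg_utils.py hunk, also allows --kv-cache-dtype fp8). The snippet below is a minimal, non-authoritative example of opting in from Python; the model name is only an illustration, and VLLM_ATTENTION_BACKEND is the pre-existing vLLM knob for selecting the FlashInfer backend.

import os

# Select the FlashInfer backend and force the TRTLLM decode kernel added in this
# patch ("0" would force it off; leaving it unset keeps the SM100 auto-detection).
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
os.environ["VLLM_USE_TRTLLM_DECODE_ATTENTION"] = "1"

from vllm import LLM  # noqa: E402

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model only
print(llm.generate("Hello, my name is")[0].outputs[0].text)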
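For reference, the conditions under which the decode path switches to trtllm_batch_decode_with_kv_cache when the environment variable is left unset, restated as a standalone predicate. The function name and signature here are ours, not part of the patch; the thresholds mirror FlashInferBackend.use_trtllm_decode_attention above.

def would_auto_use_trtllm_decode(
    is_sm100: bool,
    batch_size: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: int,
    num_kv_heads: int,
    head_size: int,
) -> bool:
    # Plain-Python restatement of the auto-detection heuristic in this patch.
    if not is_sm100:  # requires Blackwell (SM 10.0)
        return False
    if num_qo_heads % num_kv_heads != 0:  # GQA group must divide evenly
        return False
    if num_qo_heads // num_kv_heads > 8:  # group size of at most 8
        return False
    if head_size != 128:  # only head_dim == 128 is supported
        return False
    return (batch_size <= 256  # decode batch size cap
            and max_seq_len < 131072  # max decode sequence length cap
            and kv_cache_dtype == "auto")  # fp8 caches stay on the wrapper path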
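The TRTLLM kernel also assumes the HND KV-cache layout, which is why cuda.py now calls set_kv_cache_layout("HND") on SM100 and why the decode branch asserts get_kv_cache_layout() == "HND". A small sketch of what the (0, 1, 3, 2, 4) stride order returned by get_kv_cache_stride_order() does to the cache tensor; the shapes below are illustrative only.

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128

# Logical NHD view: (num_blocks, 2, block_size, num_kv_heads, head_size)
nhd = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_size)

# Permuting with the "HND" stride order yields the layout the TRTLLM kernel expects:
# (num_blocks, 2, num_kv_heads, block_size, head_size)
hnd = nhd.permute(0, 1, 3, 2, 4)
assert hnd.shape == (num_blocks, 2, num_kv_heads, block_size, head_size)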