diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py
index d4648c18f31d5..0aace571064a0 100644
--- a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py
+++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py
@@ -9,6 +9,9 @@ import torch
 from tabulate import tabulate

 from vllm import _custom_ops as ops
+from vllm.attention.ops.triton_reshape_and_cache_flash import (
+    triton_reshape_and_cache_flash,
+)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import (
@@ -31,6 +34,8 @@ def run_benchmark(
     kv_cache_dtype: str,
     kv_cache_layout: str,
     num_iters: int,
+    implementation: str,
+    benchmark_mode: str,
     device: str = "cuda",
 ) -> float:
     """Return latency (seconds) for given num_tokens."""
@@ -38,6 +43,14 @@ def run_benchmark(
     if kv_cache_dtype == "fp8" and head_size % 16:
         raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

+    if implementation not in ("cuda", "triton"):
+        raise ValueError(
+            f"Unsupported implementation: {implementation}. "
+            "Only 'cuda' and 'triton' are supported."
+        )
+    if implementation == "triton" and kv_cache_layout == "HND":
+        return float("nan")  # Triton does not support HND layout yet.
+
     current_platform.seed_everything(42)
     torch.set_default_device(device)

@@ -65,27 +78,49 @@ def run_benchmark(
         cache_layout=kv_cache_layout,
     )
     key_cache, value_cache = key_caches[0], value_caches[0]
+    # to free unused memory
+    del key_caches, value_caches

     # compute per-kernel scaling factors for fp8 conversion (if used).
     k_scale = (key.amax() / 64.0).to(torch.float32)
     v_scale = (value.amax() / 64.0).to(torch.float32)

+    if implementation == "cuda":
+        function_under_test = lambda: ops.reshape_and_cache_flash(
+            key,  # noqa: F821
+            value,  # noqa: F821
+            key_cache,  # noqa: F821
+            value_cache,  # noqa: F821
+            slot_mapping,  # noqa: F821
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+    else:
+        function_under_test = lambda: triton_reshape_and_cache_flash(
+            key,  # noqa: F821
+            value,  # noqa: F821
+            key_cache,  # noqa: F821
+            value_cache,  # noqa: F821
+            slot_mapping,  # noqa: F821
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+    if benchmark_mode == "cudagraph":
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            function_under_test()
+        torch.cuda.synchronize()
+        function_under_test = lambda: g.replay()
+
     def run_cuda_benchmark(n_iters: int) -> float:
         nonlocal key, value, key_cache, value_cache, slot_mapping
         torch.cuda.synchronize()
         start = time.perf_counter()
         for _ in range(n_iters):
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                slot_mapping,
-                kv_cache_dtype,
-                k_scale,
-                v_scale,
-            )
-            torch.cuda.synchronize()
+            function_under_test()
+        torch.cuda.synchronize()
         end = time.perf_counter()
         return (end - start) / n_iters

@@ -116,10 +151,16 @@ def main(args):
             kv_cache_dtype=args.kv_cache_dtype,
             kv_cache_layout=layout,
             num_iters=args.iters,
+            implementation=args.implementation,
+            benchmark_mode=args.mode,
             device="cuda",
         )
         rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])

+    print(
+        f"Benchmark results for implementation {args.implementation}"
+        f" (measuring with {args.mode}):"
+    )
     print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))


@@ -151,6 +192,21 @@ if __name__ == "__main__":
     )

     parser.add_argument("--iters", type=int, default=100)
+
+    parser.add_argument(
+        "--implementation",
+        type=str,
+        choices=["cuda", "triton"],
+        default="cuda",
+    )
+
+    parser.add_argument(
+        "--mode",
+        type=str,
choices=["cudagraph", "no_graph"], + default="cudagraph", + ) + args = parser.parse_args() main(args) diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 69e96dfd2cb13..1325e6883132a 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -39,6 +39,8 @@ CUDA_DEVICES = [ # We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] +RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"] + @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_layers", NUM_LAYERS) @@ -223,6 +225,7 @@ def test_reshape_and_cache( @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS) +@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS) @torch.inference_mode() def test_reshape_and_cache_flash( kv_cache_factory_flashinfer, @@ -236,9 +239,13 @@ def test_reshape_and_cache_flash( device: str, kv_cache_dtype: str, kv_cache_layout: str, + implementation: str, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) + assert implementation in ["cuda", "triton"] + if implementation == "triton" and kv_cache_layout == "HND": + pytest.skip("Triton implementation only supports NHD layout.") # fp8 conversion requires continugous memory buffer. Reduce the number of # blocks and tokens to consume less memory. @@ -298,12 +305,20 @@ def test_reshape_and_cache_flash( cloned_key_cache = key_cache_compact.clone() cloned_value_cache = value_cache_compact.clone() # Call the reshape_and_cache kernel. - opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, - (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, - k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) - ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, k_scale, v_scale) + if implementation == "cuda": + opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, + (key, value, key_cache, value_cache, slot_mapping, + kv_cache_dtype, k_scale, v_scale), + cond=(head_size == HEAD_SIZES[0])) + ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, k_scale, + v_scale) + elif implementation == "triton": + from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash) + triton_reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, k_scale, + v_scale) key_cache_compact = permute_and_compact(key_cache) value_cache_compact = permute_and_compact(value_cache) diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py new file mode 100644 index 0000000000000..0b0c706626af3 --- /dev/null +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import triton +import triton.language as tl + +from vllm.platforms import current_platform + + +@triton.jit +def reshape_and_cache_kernel_flash( + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size] + key_cache_ptr, # [num_blocks, block_size, num_heads, head_size] + value_cache_ptr, # [num_blocks, block_size, num_heads, head_size] + slot_mapping_ptr, # [num_tokens] + k_scale, # float32 + v_scale, # float32 + # strides + key_stride: tl.int64, + 
+    value_stride: tl.int64,
+    block_stride: tl.int64,
+    page_stride: tl.int64,
+    num_heads: tl.constexpr,
+    head_size: tl.constexpr,
+    block_size: tl.constexpr,
+    # FP8 flags
+    FP8_KV_CACHE: tl.constexpr,
+    # tune parameters
+    TILE_SIZE: tl.constexpr,
+):
+
+    token_idx = tl.program_id(axis=0)
+    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
+    if slot_idx < 0:
+        # Padding token that should be ignored.
+        return
+
+    tile_i = tl.program_id(axis=1)
+    tile_offs = tl.arange(0, TILE_SIZE)
+    tile_pos = tile_i * TILE_SIZE + tile_offs
+
+    block_idx = slot_idx // block_size
+    block_offset = slot_idx % block_size
+
+    src_key_idx = token_idx * key_stride
+    src_value_idx = token_idx * value_stride
+
+    tgt_idx = block_idx * block_stride + block_offset * page_stride
+
+    # [TILE_SIZE]
+    key_load = tl.load(key_ptr + src_key_idx + tile_pos,
+                       mask=tile_pos < (num_heads * head_size))
+    if FP8_KV_CACHE:
+        if key_load.dtype.is_fp8():
+            key_tile = key_load
+        else:
+            # tl.store will do the correct implicit cast to fp8,
+            # based on the key_cache_ptr.dtype.element_ty
+            key_tile = key_load / tl.load(k_scale)
+    else:
+        key_tile = key_load
+
+    # [TILE_SIZE]
+    value_load = tl.load(value_ptr + src_value_idx + tile_pos,
+                         mask=tile_pos < (num_heads * head_size))
+    if FP8_KV_CACHE:
+        if value_load.dtype.is_fp8():
+            value_tile = value_load
+        else:
+            # tl.store will do the correct implicit cast to fp8,
+            # based on the value_cache_ptr.dtype.element_ty
+            value_tile = value_load / tl.load(v_scale)
+    else:
+        value_tile = value_load
+
+    tl.store(
+        key_cache_ptr + tgt_idx + tile_pos,
+        key_tile,
+        mask=tile_pos < (num_heads * head_size),
+    )
+    tl.store(
+        value_cache_ptr + tgt_idx + tile_pos,
+        value_tile,
+        mask=tile_pos < (num_heads * head_size),
+    )
+    return
+
+
+def triton_reshape_and_cache_flash(
+    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
+    value: torch.Tensor,  # [num_tokens, num_heads, head_size]
+    # [num_blocks, block_size, num_heads, head_size]
+    key_cache: torch.Tensor,
+    # [num_blocks, block_size, num_heads, head_size]
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,  # [num_tokens]
+    kv_cache_dtype: str,  # "auto", "fp8"
+    k_scale: torch.Tensor,  # float32
+    v_scale: torch.Tensor,  # float32
+):
+    num_tokens = key.shape[0]
+    num_heads = key.shape[1]
+    head_size = key.shape[2]
+    block_size = key_cache.shape[1]
+    n = num_heads * head_size
+
+    key_stride = key.stride()[0]
+    value_stride = value.stride()[0]
+    block_stride = key_cache.stride()[0]
+    page_stride = key_cache.stride()[1]
+
+    head_stride = key_cache.stride()[2]
+    assert head_stride == head_size, "only contiguous heads are supported"
+
+    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
+        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
+    kv_cache_torch_dtype = current_platform.fp8_dtype() if \
+        kv_cache_dtype.startswith("fp8") else key_cache.dtype
+
+    if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
+            "fp8"):
+        # to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
+        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
+        key_cache = key_cache.view(kv_cache_torch_dtype)
+        value_cache = value_cache.view(kv_cache_torch_dtype)
+        assert kv_cache_torch_dtype != torch.uint8, "explicit fp8 cast and " \
+            "store to uint8 is not supported by triton reshape_and_cache_flash"

+    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
+    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
+        torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
+        torch.float8_e4m3fnuz], \
+        "unsupported dtype of KV cache tensor, got " \
+        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
+        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
+
+    # heuristics instead of autotuning
+    TILE_SIZE = min(2048, triton.next_power_of_2(n))
+    if torch.version.hip:
+        num_stages = 4
+        num_warps = 8
+    else:  # cuda
+        num_stages = 10
+        num_warps = 16
+        if torch.cuda.get_device_capability(key.device)[0] < 9:
+            TILE_SIZE = min(512, TILE_SIZE)
+
+    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
+    # using cudagraphs
+    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
+
+    reshape_and_cache_kernel_flash[grid](
+        key_ptr=key,
+        value_ptr=value,
+        key_cache_ptr=key_cache,
+        value_cache_ptr=value_cache,
+        slot_mapping_ptr=slot_mapping,
+        k_scale=k_scale,
+        v_scale=v_scale,
+        # strides
+        key_stride=key_stride,
+        value_stride=value_stride,
+        block_stride=block_stride,
+        page_stride=page_stride,
+        num_heads=num_heads,
+        head_size=head_size,
+        block_size=block_size,
+        # FP8 flags
+        FP8_KV_CACHE=FP8_KV_CACHE,
+        # tune parameters
+        TILE_SIZE=TILE_SIZE,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 722c23f150cd3..f9fbd05efc67d 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -8,6 +8,8 @@ import torch

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
+from vllm.attention.ops.triton_reshape_and_cache_flash import (
+    triton_reshape_and_cache_flash)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -291,7 +293,13 @@ class TritonAttentionImpl(AttentionImpl):
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
-            ops.reshape_and_cache_flash(
+            if self.kv_cache_dtype.startswith("fp8"):
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)
+            # triton kernel does not support uint8 kv_cache
+            # (because some explicit casts (e.g. float8_e4m3fnuz)
+            # are not supported)
+            triton_reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
@@ -303,8 +311,9 @@ class TritonAttentionImpl(AttentionImpl):
             )

         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(self.fp8_dtype)
-            value_cache = value_cache.view(self.fp8_dtype)
+            if key_cache.dtype != self.fp8_dtype:
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)

         num_tokens, num_heads, head_size = query.shape
         assert layer._q_scale_float == 1.0, \
             "A non 1.0 q_scale is not currently supported."
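Reading aid for the Triton kernel above: ignoring the fp8 scaling branch, it performs a per-token scatter into the NHD cache layout, with each program handling one token and one `TILE_SIZE`-wide chunk of the flattened `num_heads * head_size` dimension (hence the `(num_tokens, cdiv(n, TILE_SIZE))` launch grid). The pure-PyTorch sketch below mirrors that slot-to-(block, offset) arithmetic; the function name is made up for illustration and it is not part of the patch.

```python
def reshape_and_cache_flash_reference(key, value, key_cache, value_cache,
                                      slot_mapping):
    # key/value: [num_tokens, num_heads, head_size]
    # key_cache/value_cache: [num_blocks, block_size, num_heads, head_size]
    # slot_mapping: [num_tokens]; negative slots mark padding tokens
    block_size = key_cache.shape[1]
    for token_idx, slot_idx in enumerate(slot_mapping.tolist()):
        if slot_idx < 0:  # padding token, skipped by the kernel as well
            continue
        block_idx = slot_idx // block_size
        block_offset = slot_idx % block_size
        # same target as tgt_idx = block_idx * block_stride
        #                          + block_offset * page_stride in the kernel
        key_cache[block_idx, block_offset] = key[token_idx]
        value_cache[block_idx, block_offset] = value[token_idx]
```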
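For a quick manual sanity check of the two paths outside pytest, something like the sketch below can be used. It is an illustration only, assuming a CUDA device with Triton installed; `compare_implementations` is a hypothetical helper, and only the signatures introduced in this diff are exercised. The benchmark above exposes the same choice via `--implementation {cuda,triton}` and `--mode {cudagraph,no_graph}`.

```python
import torch

from vllm import _custom_ops as ops
from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash)


def compare_implementations(num_tokens=32, num_heads=8, head_size=128,
                            num_blocks=16, block_size=16):
    # Hypothetical helper: fill two cache copies via the CUDA op and the
    # Triton kernel and check that they agree (NHD layout, kv_cache_dtype
    # "auto", so no fp8 conversion is involved).
    torch.set_default_device("cuda")
    dtype = torch.float16
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype)
    value = torch.randn(num_tokens, num_heads, head_size, dtype=dtype)
    cache_shape = (num_blocks, block_size, num_heads, head_size)
    key_cache_cuda = torch.zeros(cache_shape, dtype=dtype)
    value_cache_cuda = torch.zeros(cache_shape, dtype=dtype)
    key_cache_triton = key_cache_cuda.clone()
    value_cache_triton = value_cache_cuda.clone()
    # one distinct slot per token
    slot_mapping = torch.randperm(num_blocks * block_size)[:num_tokens].long()
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    ops.reshape_and_cache_flash(key, value, key_cache_cuda, value_cache_cuda,
                                slot_mapping, "auto", k_scale, v_scale)
    triton_reshape_and_cache_flash(key, value, key_cache_triton,
                                   value_cache_triton, slot_mapping, "auto",
                                   k_scale, v_scale)
    torch.testing.assert_close(key_cache_triton, key_cache_cuda)
    torch.testing.assert_close(value_cache_triton, value_cache_cuda)


if __name__ == "__main__":
    compare_implementations()
    print("Triton and CUDA reshape_and_cache_flash agree.")
```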