mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 14:44:28 +08:00
[Core] Add Flashinfer TRTLLM Backend for Flashinfer decode path (SM100). (#19825)
Signed-off-by: Pavani Majety <pmajety@nvidia.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: shuw <shuw@nvidia.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
8020e98c9f
commit
7bd4c37ae7
240
benchmarks/kernels/benchmark_trtllm_attention.py
Normal file
240
benchmarks/kernels/benchmark_trtllm_attention.py
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import flashinfer
|
||||||
|
import torch
|
||||||
|
|
||||||
|
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||||
|
|
||||||
|
# KV Cache Layout for TRT-LLM
|
||||||
|
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
min_val, max_val = x.aminmax()
|
||||||
|
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||||
|
scale = finfo.max / amax * 0.1
|
||||||
|
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||||
|
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def benchmark_decode(
    num_seqs,
    max_seq_len,
    page_size=16,
    dtype=torch.bfloat16,
    kv_layout="HND",
    num_kv_heads=8,
    kv_cache_dtype="auto",
    head_dim=128,
    warmup=10,
    trials=20,
):
    """Time the TRT-LLM decode kernel against FlashInfer's paged-KV baseline.

    Builds a random query, KV cache and block table on CUDA, times
    ``flashinfer.decode.trtllm_batch_decode_with_kv_cache`` and then
    ``BatchDecodeWithPagedKVCacheWrapper.run`` on the same inputs, prints one
    tab-separated result row and returns the measurements.

    Args:
        num_seqs: Number of decode sequences; the query holds one row each.
        max_seq_len: Upper bound of the randomly drawn per-sequence KV length.
        page_size: Number of tokens per KV-cache page.
        dtype: Query dtype, also used for the KV cache unless FP8 is selected.
        kv_layout: Layout string passed to the baseline wrapper.
        num_kv_heads: Number of KV heads; query heads are 8x this value.
        kv_cache_dtype: ``"auto"`` or a string starting with ``"fp8"``; the
            latter quantizes the cache with ``to_float8``.
        head_dim: Size of each attention head.
        warmup: Untimed iterations run before measuring each kernel.
        trials: Timed iterations per kernel.

    Returns:
        A dict with the mean/std timings in milliseconds, the relative
        speedup and the benchmark configuration, keyed by the CSV columns.
    """
    torch.set_default_device("cuda")
    device = "cuda"
    torch.manual_seed(0)

    # Currently only HEAD_GRP_SIZE == 8 is supported
    HEAD_GRP_SIZE = 8
    MAX_SEQ_LEN = max_seq_len

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / page_size)

    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)

    # For decode, batch_size is num_decode_token
    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
    sm_scale = float(1.0 / (head_dim**0.5))
    q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
    # NOTE: these lengths come from Python's `random` module, which
    # torch.manual_seed() above does not seed.
    kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]

    max_kv_len = max(kv_lens)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size

    block_tables = torch.randint(
        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
    kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
    k_scale = v_scale = 1.0

    if kv_cache_dtype.startswith("fp8"):
        # Only the quantized data is kept; the kernels run with unit scales.
        kv_cache, _ = to_float8(kv_cache)

    # Benchmark TRT decode
    def trt_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            q,
            kv_cache,
            workspace_buffer,
            num_qo_heads,
            num_kv_heads,
            sm_scale,
            block_tables,
            kv_lens_tensor,
            page_size,
            max_kv_len,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    def time_fn(fn, warmup=10, trials=20):
        """Return (mean, std) of the per-call time of ``fn`` in ms."""
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for _ in range(warmup):
            fn()
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    # TRT Decode. Forward the caller's warmup/trials instead of silently
    # falling back to time_fn's own defaults.
    trt_mean, trt_std = time_fn(trt_decode, warmup, trials)

    # Flatten the block table into the indptr/indices/last-page-length form
    # that the baseline wrapper's plan() takes. Converting the table to
    # nested Python lists once avoids collecting per-element device tensors.
    block_table_rows = block_tables.tolist()
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + page_size - 1) // page_size
        kv_indices.extend(block_table_rows[i][:num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % page_size
        if kv_last_page_len == 0:
            # An exact multiple of page_size fills its last page completely.
            kv_last_page_len = page_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
    )

    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        "NONE",
        q_data_type=dtype,
        kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
    )

    def baseline_decode():
        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)

    baseline_mean, baseline_std = time_fn(baseline_decode, warmup, trials)

    # Relative speedup as a fraction of the baseline time (0.25 means TRT is
    # 25% faster); positive means TRT is faster. The value is not multiplied
    # by 100 despite the "percent" name.
    speedup_percent = (baseline_mean - trt_mean) / baseline_mean

    print(
        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "num_seqs": num_seqs,
        "trt_mean": trt_mean,
        "trt_std": trt_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(dtype),
        "kv_cache_dtype": kv_cache_dtype,
        "page_size": page_size,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "max_seq_len": max_seq_len,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def write_results_to_csv(results, filename=None):
    """Append benchmark result rows to a CSV file.

    Args:
        results: Iterable of dicts keyed by the benchmark column names.
        filename: Destination path. When omitted, a timestamped name in the
            current working directory is generated.

    The header row is written only when the destination did not exist
    before this call; existing files are appended to as-is.
    """
    if filename is None:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{stamp}.csv"

    columns = (
        "num_seqs",
        "trt_mean",
        "trt_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "page_size",
        "num_kv_heads",
        "head_dim",
        "max_seq_len",
    )

    # Checked before open(), because append mode creates the file.
    needs_header = not os.path.exists(filename)

    with open(filename, "a", newline="") as out:
        sink = csv.DictWriter(out, fieldnames=columns)
        if needs_header:
            sink.writeheader()
        sink.writerows(results)

    print(f"Results written to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Sweep every (max_seq_len, batch size) pair, first with the KV cache in
    # the query dtype and then with an FP8 KV cache, and dump all rows to a
    # single CSV file at the end.
    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    column_banner = (
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
    )
    sweeps = (
        ("Running benchmark for kv_cache_dtype: bfloat16", "auto"),
        ("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8", "fp8"),
    )

    for title, cache_dtype in sweeps:
        print(title)
        print(column_banner)
        for max_seq_len in max_seq_lens:
            for bs in num_seqs:
                all_results.append(
                    benchmark_decode(
                        bs,
                        max_seq_len,
                        dtype=torch.bfloat16,
                        kv_cache_dtype=cache_dtype,
                    )
                )

    # Write all results to CSV
    write_results_to_csv(all_results)
|
||||||
@ -0,0 +1,140 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import flashinfer
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
if not current_platform.is_device_capability(100):
|
||||||
|
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
|
||||||
|
allow_module_level=True)
|
||||||
|
|
||||||
|
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||||
|
|
||||||
|
# KV Cache Layout for TRT-LLM
|
||||||
|
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||||
|
|
||||||
|
NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
|
||||||
|
HEAD_SIZES = [128]
|
||||||
|
BLOCK_SIZES = [16, 32]
|
||||||
|
DTYPES = [torch.float16, torch.bfloat16]
|
||||||
|
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||||
|
SOFT_CAPS = [None, 30.0, 50.0]
|
||||||
|
|
||||||
|
|
||||||
|
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
min_val, max_val = x.aminmax()
|
||||||
|
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||||
|
scale = finfo.max / amax * 0.1
|
||||||
|
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||||
|
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", ["HND"])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
    kv_lens: list[int],
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    kv_layout: str,
) -> None:
    """Check the TRT-LLM decode kernel against FlashInfer's baseline decode.

    Both paths get the same random query, KV cache and block table; the
    outputs must agree within ``atol=rtol=1e-2``.

    Args:
        kv_lens: KV length of each sequence in the batch.
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Size of each attention head.
        dtype: Dtype of the query and the KV cache.
        block_size: Number of tokens per KV-cache page.
        soft_cap: Logits soft cap passed to the baseline ``plan()``.
        kv_layout: ``"NHD"`` or ``"HND"``; selects the KV-cache shape.
    """
    torch.set_default_device("cuda")
    current_platform.seed_everything(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]

    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    k_scale = v_scale = 1.0

    # Flatten the block table into the indptr/indices/last-page-length form
    # taken by the baseline wrapper's plan(). One tolist() call avoids
    # collecting per-element device tensors in a Python list.
    block_table_rows = block_tables.tolist()
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_table_rows[i][:num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            # An exact multiple of block_size fills its last page completely.
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_query_heads // num_kv_heads) > 4))
    wrapper.plan(kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 "NONE",
                 q_data_type=dtype,
                 kv_data_type=dtype,
                 logits_soft_cap=soft_cap)

    output = wrapper.run(query, key_value_cache, scale)

    # TRTLLM Decode
    # NOTE(review): soft_cap is passed to the baseline plan() above but not
    # to the TRT-LLM call below -- confirm the two paths are expected to
    # match when soft_cap is not None.
    kv_lens_tensor = torch.tensor(kv_lens,
                                  dtype=torch.int,
                                  device=query.device)
    output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query.contiguous(),
        key_value_cache,
        workspace_buffer,
        num_query_heads,
        num_kv_heads,
        scale,
        block_tables,
        kv_lens_tensor,
        block_size,
        max_kv_len,
        "auto",
        k_scale,
        v_scale,
    )

    # assert_close raises with its own mismatch report, so no extra message
    # is attached here.
    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2)
|
||||||
@ -11,7 +11,8 @@ from vllm.multimodal import MultiModalPlaceholderMap
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||||
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
|
from flashinfer.decode import (CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
|
||||||
|
trtllm_batch_decode_with_kv_cache)
|
||||||
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
|
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
|
||||||
|
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
@ -22,7 +23,10 @@ except ImportError:
|
|||||||
BatchDecodeWithPagedKVCacheWrapper = None
|
BatchDecodeWithPagedKVCacheWrapper = None
|
||||||
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
|
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
|
||||||
BatchPrefillWithPagedKVCacheWrapper = None
|
BatchPrefillWithPagedKVCacheWrapper = None
|
||||||
|
trtllm_batch_decode_with_kv_cache = None
|
||||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
||||||
|
raise ImportError("FlashInfer is not installed. Please install it from "
|
||||||
|
"https://github.com/flashinfer-ai/flashinfer") from None
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -40,6 +44,7 @@ from vllm.attention.layer import Attention
|
|||||||
from vllm.attention.ops.paged_attn import PagedAttention
|
from vllm.attention.ops.paged_attn import PagedAttention
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
|
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
|
||||||
make_tensor_with_pad)
|
make_tensor_with_pad)
|
||||||
|
|
||||||
@ -49,10 +54,9 @@ if TYPE_CHECKING:
|
|||||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||||
ModelInputForGPUWithSamplingMetadata)
|
ModelInputForGPUWithSamplingMetadata)
|
||||||
|
|
||||||
FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD"
|
|
||||||
|
|
||||||
|
|
||||||
class FlashInferBackend(AttentionBackend):
|
class FlashInferBackend(AttentionBackend):
|
||||||
|
cached_sm100a_supported: Optional[bool] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
@ -85,7 +89,7 @@ class FlashInferBackend(AttentionBackend):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_kv_cache_stride_order() -> Tuple[int, ...]:
|
def get_kv_cache_stride_order() -> Tuple[int, ...]:
|
||||||
cache_layout = FLASHINFER_KV_CACHE_LAYOUT
|
cache_layout = FlashInferState.get_kv_cache_layout()
|
||||||
assert (cache_layout in ("NHD", "HND"))
|
assert (cache_layout in ("NHD", "HND"))
|
||||||
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3,
|
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3,
|
||||||
2, 4)
|
2, 4)
|
||||||
@ -119,6 +123,47 @@ class FlashInferBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||||
|
|
||||||
|
@staticmethod
def use_trtllm_decode_attention(
    batch_size: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: Optional[int],
    num_kv_heads: Optional[int],
    attn_head_size: Optional[int],
) -> bool:
    """Decide whether decode should go through the TRTLLM attention kernel.

    Returns False when the platform does not report device capability 100,
    when any head parameter is None, when the query/KV head ratio exceeds 8
    or is not an integer, or when the head size is not 128. Otherwise the
    VLLM_USE_TRTLLM_DECODE_ATTENTION environment value decides if it is set
    (any value other than "0" enables the path). If it is unset, the path
    is enabled for batch_size <= 256, max_seq_len < 131072 and
    kv_cache_dtype == "auto".
    """
    # The capability probe result is cached on the class after first use.
    if FlashInferBackend.cached_sm100a_supported is None:
        FlashInferBackend.cached_sm100a_supported = (
            current_platform.has_device_capability(100))
    if not FlashInferBackend.cached_sm100a_supported:
        return False

    # Check if the dimensions are supported by TRTLLM decode attention
    if (attn_head_size is None or num_qo_heads is None
            or num_kv_heads is None):
        return False
    if num_qo_heads // num_kv_heads > 8:
        return False
    if num_qo_heads % num_kv_heads != 0:
        return False
    if attn_head_size != 128:
        return False

    override = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if override is None:
        # Environment variable not set - use auto-detection
        auto_selected = (FlashInferBackend.cached_sm100a_supported
                         and batch_size <= 256 and max_seq_len < 131072
                         and kv_cache_dtype == "auto")
        if auto_selected:
            logger.warning_once(
                "Using TRTLLM decode attention (auto-detected).")
        return auto_selected

    logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                     override)
    # Environment variable is set - respect it. Only an explicit "0" turns
    # the path off, because it is otherwise enabled automatically when the
    # batch size condition is satisfied.
    enabled = override != "0"
    if enabled:
        logger.info_once("Using TRTLLM decode attention.")
    return enabled
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PerLayerParameters:
|
class PerLayerParameters:
|
||||||
@ -207,10 +252,19 @@ class FlashInferState(AttentionState):
|
|||||||
device=self.runner.device)
|
device=self.runner.device)
|
||||||
return self._workspace_buffer
|
return self._workspace_buffer
|
||||||
|
|
||||||
def get_kv_cache_layout(self):
|
@staticmethod
|
||||||
if self._kv_cache_layout is None:
|
def get_kv_cache_layout():
|
||||||
self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT
|
from vllm.v1.attention.backends.utils import _KV_CACHE_LAYOUT_OVERRIDE
|
||||||
return self._kv_cache_layout
|
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
|
||||||
|
logger.info_once("Using KV cache layout %s",
|
||||||
|
_KV_CACHE_LAYOUT_OVERRIDE)
|
||||||
|
return _KV_CACHE_LAYOUT_OVERRIDE
|
||||||
|
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
|
||||||
|
if cache_layout is None:
|
||||||
|
logger.info_once("Using default KV cache layout NHD")
|
||||||
|
return "NHD"
|
||||||
|
logger.info_once("Using KV cache layout %s", cache_layout)
|
||||||
|
return cache_layout
|
||||||
|
|
||||||
def _get_prefill_wrapper(self):
|
def _get_prefill_wrapper(self):
|
||||||
if self._prefill_wrapper is None:
|
if self._prefill_wrapper is None:
|
||||||
@ -323,6 +377,8 @@ class FlashInferState(AttentionState):
|
|||||||
num_prefill_tokens=0,
|
num_prefill_tokens=0,
|
||||||
num_decode_tokens=batch_size,
|
num_decode_tokens=batch_size,
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
|
max_decode_seq_len=0,
|
||||||
|
seq_lens_tensor=self._graph_seq_lens,
|
||||||
block_tables=self._graph_block_tables,
|
block_tables=self._graph_block_tables,
|
||||||
paged_kv_indptr=paged_kv_indptr_tensor_host,
|
paged_kv_indptr=paged_kv_indptr_tensor_host,
|
||||||
paged_kv_indices=paged_kv_indices_tensor_host,
|
paged_kv_indices=paged_kv_indices_tensor_host,
|
||||||
@ -348,6 +404,8 @@ class FlashInferState(AttentionState):
|
|||||||
attn_metadata,
|
attn_metadata,
|
||||||
is_encoder_decoder_model: bool = False):
|
is_encoder_decoder_model: bool = False):
|
||||||
return {
|
return {
|
||||||
|
"block_tables": attn_metadata.block_tables,
|
||||||
|
"seq_lens_tensor": attn_metadata.seq_lens_tensor,
|
||||||
"slot_mapping": attn_metadata.slot_mapping,
|
"slot_mapping": attn_metadata.slot_mapping,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -355,7 +413,13 @@ class FlashInferState(AttentionState):
|
|||||||
input_buffers,
|
input_buffers,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
is_encoder_decoder_model: bool = False):
|
is_encoder_decoder_model: bool = False):
|
||||||
return
|
# FlashInfer-specific logic: copy additional tensors
|
||||||
|
num_total_blocks = attn_metadata.decode_metadata.seq_lens_tensor.shape[
|
||||||
|
0]
|
||||||
|
input_buffers["seq_lens_tensor"][:num_total_blocks].copy_(
|
||||||
|
attn_metadata.seq_lens_tensor, non_blocking=True)
|
||||||
|
input_buffers["block_tables"][:num_total_blocks].copy_(
|
||||||
|
attn_metadata.block_tables, non_blocking=True)
|
||||||
|
|
||||||
def begin_forward(self, model_input):
|
def begin_forward(self, model_input):
|
||||||
assert not self._is_graph_capturing
|
assert not self._is_graph_capturing
|
||||||
@ -388,6 +452,8 @@ class FlashInferMetadata(AttentionMetadata):
|
|||||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||||
# requests only.
|
# requests only.
|
||||||
max_prefill_seq_len: int
|
max_prefill_seq_len: int
|
||||||
|
max_decode_seq_len: int
|
||||||
|
|
||||||
# Number of query tokens for each request in the batch.
|
# Number of query tokens for each request in the batch.
|
||||||
# Currently, we require that all requests have the same number of query
|
# Currently, we require that all requests have the same number of query
|
||||||
# tokens during the decoding phase. When speculative decoding is enabled,
|
# tokens during the decoding phase. When speculative decoding is enabled,
|
||||||
@ -790,6 +856,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
use_captured_graph = cuda_graph_pad_size != -1
|
use_captured_graph = cuda_graph_pad_size != -1
|
||||||
|
|
||||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||||
|
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||||
num_decode_tokens = self.num_decode_tokens
|
num_decode_tokens = self.num_decode_tokens
|
||||||
decode_query_len = max(query_lens[self.num_prefills:], default=1)
|
decode_query_len = max(query_lens[self.num_prefills:], default=1)
|
||||||
|
|
||||||
@ -895,6 +962,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
num_prefill_tokens=self.num_prefill_tokens,
|
num_prefill_tokens=self.num_prefill_tokens,
|
||||||
num_decode_tokens=num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
max_prefill_seq_len=max_prefill_seq_len,
|
max_prefill_seq_len=max_prefill_seq_len,
|
||||||
|
max_decode_seq_len=max_decode_seq_len,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
paged_kv_indptr=paged_kv_indptr_tensor,
|
paged_kv_indptr=paged_kv_indptr_tensor,
|
||||||
paged_kv_indices=paged_kv_indices_tensor,
|
paged_kv_indices=paged_kv_indices_tensor,
|
||||||
@ -1081,13 +1149,36 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
assert decode_meta.decode_wrapper._logits_soft_cap == (
|
assert decode_meta.decode_wrapper._logits_soft_cap == (
|
||||||
logits_soft_cap or 0.0)
|
logits_soft_cap or 0.0)
|
||||||
assert decode_meta.decode_wrapper._sm_scale == softmax_scale
|
assert decode_meta.decode_wrapper._sm_scale == softmax_scale
|
||||||
|
# TODO: @pavanimajety Remove this once the switch happens
|
||||||
decode_output = decode_meta.decode_wrapper.run(
|
# inside flashinfer.
|
||||||
decode_query,
|
if not FlashInferBackend.use_trtllm_decode_attention(
|
||||||
kv_cache.permute(*stride_order),
|
num_decode_tokens, attn_metadata.max_decode_seq_len,
|
||||||
k_scale=layer._k_scale_float,
|
kv_cache_dtype, attn_metadata.num_qo_heads,
|
||||||
v_scale=layer._v_scale_float,
|
attn_metadata.num_kv_heads, attn_metadata.head_dim):
|
||||||
)
|
decode_output = decode_meta.decode_wrapper.run(
|
||||||
|
decode_query,
|
||||||
|
kv_cache.permute(*stride_order),
|
||||||
|
k_scale=layer._k_scale_float,
|
||||||
|
v_scale=layer._v_scale_float,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
workspace_buffer = (
|
||||||
|
decode_meta.decode_wrapper._int_workspace_buffer)
|
||||||
|
assert FlashInferState.get_kv_cache_layout() == "HND"
|
||||||
|
decode_output = trtllm_batch_decode_with_kv_cache(
|
||||||
|
query=decode_query,
|
||||||
|
kv_cache=kv_cache.permute(*stride_order),
|
||||||
|
workspace_buffer=workspace_buffer,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
scale=softmax_scale,
|
||||||
|
block_tables=attn_metadata.block_tables,
|
||||||
|
seq_lens=decode_meta.seq_lens_tensor,
|
||||||
|
block_size=attn_metadata.page_size,
|
||||||
|
max_seq_len=attn_metadata.max_decode_seq_len,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
k_scale=layer._k_scale_float,
|
||||||
|
v_scale=layer._v_scale_float)
|
||||||
|
|
||||||
if prefill_output is None and decode_output is not None:
|
if prefill_output is None and decode_output is not None:
|
||||||
# Decode only batch.
|
# Decode only batch.
|
||||||
|
|||||||
@ -1424,6 +1424,8 @@ class EngineArgs:
|
|||||||
from vllm.attention.utils.fa_utils import (
|
from vllm.attention.utils.fa_utils import (
|
||||||
flash_attn_supports_fp8)
|
flash_attn_supports_fp8)
|
||||||
supported = flash_attn_supports_fp8()
|
supported = flash_attn_supports_fp8()
|
||||||
|
elif envs.VLLM_USE_TRTLLM_DECODE_ATTENTION:
|
||||||
|
supported = True
|
||||||
if not supported:
|
if not supported:
|
||||||
_raise_or_fallback(feature_name="--kv-cache-dtype",
|
_raise_or_fallback(feature_name="--kv-cache-dtype",
|
||||||
recommend_to_remove=False)
|
recommend_to_remove=False)
|
||||||
|
|||||||
@ -959,7 +959,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# consumer. This is only applicable when using NixlConnector in a
|
# consumer. This is only applicable when using NixlConnector in a
|
||||||
# disaggregated decode-prefill setup.
|
# disaggregated decode-prefill setup.
|
||||||
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
|
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
|
||||||
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
|
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),
|
||||||
|
|
||||||
|
# If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
|
||||||
|
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
|
||||||
|
lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
# --8<-- [end:env-vars-definition]
|
# --8<-- [end:env-vars-definition]
|
||||||
|
|||||||
@ -244,6 +244,10 @@ class CudaPlatformBase(Platform):
|
|||||||
|
|
||||||
if selected_backend == _Backend.FLASHINFER:
|
if selected_backend == _Backend.FLASHINFER:
|
||||||
logger.info_once("Using FlashInfer backend on V1 engine.")
|
logger.info_once("Using FlashInfer backend on V1 engine.")
|
||||||
|
if cls.has_device_capability(100):
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
set_kv_cache_layout)
|
||||||
|
set_kv_cache_layout("HND")
|
||||||
return FLASHINFER_V1
|
return FLASHINFER_V1
|
||||||
elif selected_backend == _Backend.FLEX_ATTENTION:
|
elif selected_backend == _Backend.FLEX_ATTENTION:
|
||||||
logger.info_once("Using FlexAttention backend on V1 engine.")
|
logger.info_once("Using FlexAttention backend on V1 engine.")
|
||||||
@ -271,9 +275,13 @@ class CudaPlatformBase(Platform):
|
|||||||
supports_head_size(FLASHINFER_V1, head_size):
|
supports_head_size(FLASHINFER_V1, head_size):
|
||||||
try:
|
try:
|
||||||
import flashinfer # noqa: F401
|
import flashinfer # noqa: F401
|
||||||
|
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
set_kv_cache_layout)
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Using FlashInfer backend on V1 engine by default for "
|
"Using FlashInfer backend with HND KV cache layout on "
|
||||||
"Blackwell (SM 10.0) GPUs.")
|
"V1 engine by default for Blackwell (SM 10.0) GPUs.")
|
||||||
|
set_kv_cache_layout("HND")
|
||||||
return FLASHINFER_V1
|
return FLASHINFER_V1
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
@ -293,6 +301,13 @@ class CudaPlatformBase(Platform):
|
|||||||
# Backends for V0 engine
|
# Backends for V0 engine
|
||||||
if selected_backend == _Backend.FLASHINFER:
|
if selected_backend == _Backend.FLASHINFER:
|
||||||
logger.info("Using FlashInfer backend.")
|
logger.info("Using FlashInfer backend.")
|
||||||
|
if cls.has_device_capability(100):
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
set_kv_cache_layout)
|
||||||
|
logger.info_once(
|
||||||
|
"Using HND KV cache layout on V1 engine by default for "
|
||||||
|
"Blackwell (SM 10.0) GPUs.")
|
||||||
|
set_kv_cache_layout("HND")
|
||||||
return "vllm.attention.backends.flashinfer.FlashInferBackend"
|
return "vllm.attention.backends.flashinfer.FlashInferBackend"
|
||||||
elif selected_backend == _Backend.XFORMERS:
|
elif selected_backend == _Backend.XFORMERS:
|
||||||
logger.info("Using XFormers backend.")
|
logger.info("Using XFormers backend.")
|
||||||
|
|||||||
@ -10,11 +10,13 @@ import torch
|
|||||||
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
||||||
BatchPrefillWithPagedKVCacheWrapper,
|
BatchPrefillWithPagedKVCacheWrapper,
|
||||||
MultiLevelCascadeAttentionWrapper)
|
MultiLevelCascadeAttentionWrapper)
|
||||||
|
from flashinfer.decode import trtllm_batch_decode_with_kv_cache
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionType)
|
AttentionType)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
@ -38,6 +40,7 @@ logger = init_logger(__name__)
|
|||||||
class FlashInferBackend(AttentionBackend):
|
class FlashInferBackend(AttentionBackend):
|
||||||
|
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
|
cached_sm100a_supported: Optional[bool] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_head_sizes(cls) -> list[int]:
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
@ -93,6 +96,57 @@ class FlashInferBackend(AttentionBackend):
|
|||||||
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||||
return stride_order
|
return stride_order
|
||||||
|
|
||||||
|
@staticmethod
def use_trtllm_decode_attention(
    batch_size: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: int,
    num_kv_heads: int,
    attn_head_size: int,
) -> bool:
    """Decide whether decode should take the TRTLLM decode-attention path.

    The answer is False unless the device reports capability 100 and the
    head configuration is one this check accepts (query heads divisible by
    KV heads, at most 8 query heads per KV head, head size 128). Past those
    hard requirements, ``VLLM_USE_TRTLLM_DECODE_ATTENTION`` decides when it
    is set; otherwise a batch-size / sequence-length / cache-dtype
    heuristic is used.

    Args:
        batch_size: Number of decode requests in the batch.
        max_seq_len: Longest sequence length in the batch.
        kv_cache_dtype: KV cache dtype name. Auto-detection only enables
            the path when this equals the string ``"auto"``.
            NOTE(review): a ``torch.dtype`` argument never equals
            ``"auto"``, so callers must pass the config string for
            auto-detection to have any effect.
        num_qo_heads: Number of query/output heads.
        num_kv_heads: Number of KV heads.
        attn_head_size: Attention head dimension.

    Returns:
        True if the TRTLLM decode path should be used.
    """
    # The capability query is cached on the class so it runs only once.
    if FlashInferBackend.cached_sm100a_supported is None:
        FlashInferBackend.cached_sm100a_supported = (
            current_platform.has_device_capability(100))
    if not FlashInferBackend.cached_sm100a_supported:
        return False

    # Reject unsupported head configurations. The num_kv_heads guard comes
    # first so a zero value yields False instead of ZeroDivisionError.
    if (num_kv_heads <= 0 or num_qo_heads % num_kv_heads != 0
            or num_qo_heads // num_kv_heads > 8 or attn_head_size != 128):
        return False

    env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                         env_value)
        # Only an explicit "0" disables the path; any other value set by
        # the user forces it on, regardless of the batch-size heuristic.
        use_trtllm = env_value != "0"
        if use_trtllm:
            logger.info_once(
                "Using TRTLLM decode attention "
                "(forced by VLLM_USE_TRTLLM_DECODE_ATTENTION=%s).",
                env_value)
        return use_trtllm

    # Environment variable not set: auto-detect. Device support was already
    # established above, so only the workload limits remain.
    use_trtllm = (batch_size <= 256 and max_seq_len < 131072
                  and kv_cache_dtype == "auto")
    if use_trtllm:
        # Routine selection notice, not a problem report.
        logger.info_once("Using TRTLLM decode attention (auto-detected).")
    return use_trtllm
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
|
||||||
|
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||||
|
return torch.float8_e4m3fn
|
||||||
|
elif kv_cache_dtype == "fp8_e5m2":
|
||||||
|
return torch.float8_e5m2
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlashInferMetadata:
|
class FlashInferMetadata:
|
||||||
@ -127,12 +181,18 @@ class FlashInferMetadata:
|
|||||||
# Block size of vllm
|
# Block size of vllm
|
||||||
page_size: int
|
page_size: int
|
||||||
# The data type of the paged kv cache
|
# The data type of the paged kv cache
|
||||||
data_type: torch.dtype
|
kv_data_type: torch.dtype
|
||||||
# The data type of the query
|
# The data type of the query
|
||||||
q_data_type: torch.dtype
|
q_data_type: torch.dtype
|
||||||
|
|
||||||
slot_mapping: torch.Tensor
|
slot_mapping: torch.Tensor
|
||||||
|
|
||||||
|
# For flashinfer trtllm batch decode
|
||||||
|
max_seq_len: int
|
||||||
|
seq_lens: torch.Tensor
|
||||||
|
block_table_tensor: torch.Tensor
|
||||||
|
workspace_buffer: torch.Tensor
|
||||||
|
|
||||||
# For handling prefill decode split
|
# For handling prefill decode split
|
||||||
num_decodes: int
|
num_decodes: int
|
||||||
num_decode_tokens: int
|
num_decode_tokens: int
|
||||||
@ -299,6 +359,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
window_left=self.global_hyperparameters.window_left,
|
window_left=self.global_hyperparameters.window_left,
|
||||||
logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
|
logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
|
||||||
q_data_type=attn_metadata.q_data_type,
|
q_data_type=attn_metadata.q_data_type,
|
||||||
|
kv_data_type=attn_metadata.kv_data_type,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Regular attention (common case).
|
# Regular attention (common case).
|
||||||
@ -334,28 +395,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
logits_soft_cap=self.global_hyperparameters.
|
logits_soft_cap=self.global_hyperparameters.
|
||||||
logits_soft_cap,
|
logits_soft_cap,
|
||||||
q_data_type=attn_metadata.q_data_type,
|
q_data_type=attn_metadata.q_data_type,
|
||||||
kv_data_type=attn_metadata.data_type,
|
kv_data_type=attn_metadata.kv_data_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._num_decodes > 0:
|
if self._num_decodes > 0:
|
||||||
attn_metadata.decode_wrapper = self._get_decode_wrapper()
|
attn_metadata.decode_wrapper = self._get_decode_wrapper()
|
||||||
attn_metadata.decode_wrapper.plan(
|
if not FlashInferBackend.use_trtllm_decode_attention(
|
||||||
attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
|
self._num_decodes, attn_metadata.max_seq_len,
|
||||||
attn_metadata.paged_kv_indices,
|
attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
|
||||||
attn_metadata.paged_kv_last_page_len[:self._num_decodes],
|
attn_metadata.num_kv_heads, attn_metadata.head_dim):
|
||||||
attn_metadata.num_qo_heads,
|
attn_metadata.decode_wrapper.plan(
|
||||||
attn_metadata.num_kv_heads,
|
attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
|
||||||
attn_metadata.head_dim,
|
attn_metadata.paged_kv_indices,
|
||||||
attn_metadata.page_size,
|
attn_metadata.paged_kv_last_page_len[:self.
|
||||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
_num_decodes],
|
||||||
pos_encoding_mode="NONE",
|
attn_metadata.num_qo_heads,
|
||||||
sm_scale=self.global_hyperparameters.sm_scale,
|
attn_metadata.num_kv_heads,
|
||||||
window_left=self.global_hyperparameters.window_left,
|
attn_metadata.head_dim,
|
||||||
logits_soft_cap=self.global_hyperparameters.
|
attn_metadata.page_size,
|
||||||
logits_soft_cap,
|
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||||
q_data_type=attn_metadata.q_data_type,
|
pos_encoding_mode="NONE",
|
||||||
kv_data_type=attn_metadata.data_type,
|
sm_scale=self.global_hyperparameters.sm_scale,
|
||||||
)
|
window_left=self.global_hyperparameters.window_left,
|
||||||
|
logits_soft_cap=self.global_hyperparameters.
|
||||||
|
logits_soft_cap,
|
||||||
|
q_data_type=attn_metadata.q_data_type,
|
||||||
|
kv_data_type=attn_metadata.kv_data_type,
|
||||||
|
)
|
||||||
|
|
||||||
def build(self, common_prefix_len: int,
|
def build(self, common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata):
|
common_attn_metadata: CommonAttentionMetadata):
|
||||||
@ -368,6 +434,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
page_size = self.kv_cache_spec.block_size
|
page_size = self.kv_cache_spec.block_size
|
||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
qo_indptr = common_attn_metadata.query_start_loc
|
qo_indptr = common_attn_metadata.query_start_loc
|
||||||
|
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
|
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
|
||||||
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
|
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
|
||||||
@ -416,7 +483,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
paged_kv_last_page_len = seq_lens % page_size
|
paged_kv_last_page_len = seq_lens % page_size
|
||||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
|
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
|
||||||
page_size, paged_kv_last_page_len)
|
page_size, paged_kv_last_page_len)
|
||||||
|
cache_dtype = self.runner.cache_config.cache_dtype
|
||||||
|
if cache_dtype.startswith("fp8"):
|
||||||
|
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||||
|
cache_dtype)
|
||||||
|
else:
|
||||||
|
kv_cache_dtype = self.kv_cache_spec.dtype
|
||||||
attn_metadata = FlashInferMetadata(
|
attn_metadata = FlashInferMetadata(
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
qo_indptr=qo_indptr,
|
qo_indptr=qo_indptr,
|
||||||
@ -427,7 +499,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
num_kv_heads=self.kv_cache_spec.num_kv_heads,
|
num_kv_heads=self.kv_cache_spec.num_kv_heads,
|
||||||
head_dim=self.kv_cache_spec.head_size,
|
head_dim=self.kv_cache_spec.head_size,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
data_type=self.kv_cache_spec.dtype,
|
kv_data_type=kv_cache_dtype,
|
||||||
q_data_type=self.runner.dtype,
|
q_data_type=self.runner.dtype,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
num_decodes=self._num_decodes,
|
num_decodes=self._num_decodes,
|
||||||
@ -439,6 +511,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
shared_kv_page_indptr=shared_kv_page_indptr,
|
shared_kv_page_indptr=shared_kv_page_indptr,
|
||||||
shared_kv_page_indices=shared_kv_page_indices,
|
shared_kv_page_indices=shared_kv_page_indices,
|
||||||
shared_kv_last_page_len=shared_kv_last_page_len,
|
shared_kv_last_page_len=shared_kv_last_page_len,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
block_table_tensor=block_table_tensor,
|
||||||
|
workspace_buffer=self._workspace_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._plan(attn_metadata)
|
self._plan(attn_metadata)
|
||||||
@ -514,7 +590,11 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
query: shape = [num_tokens, num_heads, head_size]
|
query: shape = [num_tokens, num_heads, head_size]
|
||||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||||
kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size]
|
kv_cache: shape -
|
||||||
|
# NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
|
||||||
|
# HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
|
||||||
|
|
||||||
|
|
||||||
attn_metadata: Metadata for attention.
|
attn_metadata: Metadata for attention.
|
||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
@ -560,6 +640,13 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
layer._v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
||||||
|
# to process the cache when the kv_cache_dtype is fp8
|
||||||
|
if self.kv_cache_dtype.startswith("fp8"):
|
||||||
|
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||||
|
self.kv_cache_dtype)
|
||||||
|
kv_cache = kv_cache.view(torch_dtype)
|
||||||
|
|
||||||
window_left = (self.sliding_window[0]
|
window_left = (self.sliding_window[0]
|
||||||
if self.sliding_window is not None else -1)
|
if self.sliding_window is not None else -1)
|
||||||
|
|
||||||
@ -597,21 +684,45 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
v_scale=layer._v_scale_float,
|
v_scale=layer._v_scale_float,
|
||||||
out=output[num_decode_tokens:],
|
out=output[num_decode_tokens:],
|
||||||
)
|
)
|
||||||
|
|
||||||
if decode_wrapper := attn_metadata.decode_wrapper:
|
if decode_wrapper := attn_metadata.decode_wrapper:
|
||||||
decode_query = query[:num_decode_tokens]
|
decode_query = query[:num_decode_tokens]
|
||||||
assert decode_query.shape[0] == num_decode_tokens
|
assert decode_query.shape[0] == num_decode_tokens
|
||||||
assert decode_wrapper is not None
|
if not FlashInferBackend.use_trtllm_decode_attention(
|
||||||
assert decode_wrapper._window_left == window_left
|
attn_metadata.num_decodes, attn_metadata.max_seq_len,
|
||||||
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
|
self.kv_cache_dtype, attn_metadata.num_qo_heads,
|
||||||
or 0.0)
|
attn_metadata.num_kv_heads, attn_metadata.head_dim):
|
||||||
assert decode_wrapper._sm_scale == self.scale
|
assert decode_wrapper is not None
|
||||||
decode_wrapper.run(
|
assert decode_wrapper._window_left == window_left
|
||||||
decode_query,
|
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
|
||||||
kv_cache.permute(*stride_order),
|
or 0.0)
|
||||||
k_scale=layer._k_scale_float,
|
assert decode_wrapper._sm_scale == self.scale
|
||||||
v_scale=layer._v_scale_float,
|
decode_wrapper.run(
|
||||||
out=output[:num_decode_tokens],
|
decode_query,
|
||||||
)
|
kv_cache.permute(*stride_order),
|
||||||
|
k_scale=layer._k_scale_float,
|
||||||
|
v_scale=layer._v_scale_float,
|
||||||
|
out=output[:num_decode_tokens],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||||
|
if num_decode_tokens > 0:
|
||||||
|
assert get_kv_cache_layout() == "HND"
|
||||||
|
output[:num_decode_tokens] = (
|
||||||
|
trtllm_batch_decode_with_kv_cache(
|
||||||
|
query=decode_query,
|
||||||
|
kv_cache=kv_cache.permute(*stride_order),
|
||||||
|
workspace_buffer=attn_metadata.workspace_buffer,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
scale=self.scale,
|
||||||
|
block_tables=attn_metadata.
|
||||||
|
block_table_tensor[:num_decode_tokens],
|
||||||
|
seq_lens=attn_metadata.
|
||||||
|
seq_lens[:num_decode_tokens],
|
||||||
|
block_size=attn_metadata.page_size,
|
||||||
|
max_seq_len=attn_metadata.max_seq_len,
|
||||||
|
kv_cache_dtype=self.kv_cache_dtype,
|
||||||
|
k_scale=layer._k_scale_float,
|
||||||
|
v_scale=layer._v_scale_float,
|
||||||
|
))
|
||||||
return output_padded
|
return output_padded
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
_KV_CACHE_LAYOUT_OVERRIDE = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -103,6 +104,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
|||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def get_kv_cache_layout():
|
def get_kv_cache_layout():
|
||||||
|
global _KV_CACHE_LAYOUT_OVERRIDE
|
||||||
# Override with format specified by the user.
|
# Override with format specified by the user.
|
||||||
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
|
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
|
||||||
if cache_layout is None:
|
if cache_layout is None:
|
||||||
@ -110,10 +112,16 @@ def get_kv_cache_layout():
|
|||||||
else:
|
else:
|
||||||
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
|
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
|
||||||
"detected. Setting KV cache layout to %s.", cache_layout)
|
"detected. Setting KV cache layout to %s.", cache_layout)
|
||||||
|
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
|
||||||
|
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
|
||||||
return cache_layout
|
return cache_layout
|
||||||
|
|
||||||
|
|
||||||
|
def set_kv_cache_layout(cache_layout: str):
    """Set a programmatic override for the KV cache layout.

    Args:
        cache_layout: Layout name to report from ``get_kv_cache_layout()``
            (this file uses ``"HND"`` and ``"NHD"``).
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
    # get_kv_cache_layout() is wrapped in functools.lru_cache, so a value
    # computed before this call would otherwise be returned forever and
    # the override would be silently ignored. Drop the memoized result so
    # the next lookup sees the new layout.
    get_kv_cache_layout.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PerLayerParameters:
|
class PerLayerParameters:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user