[Kernel] Support CUDA Graphs in 3D Triton Attention Kernel (#28306)
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: jvlunteren <161835099+jvlunteren@users.noreply.github.com>
Co-authored-by: Thomas Parnell <tom.parnell@gmail.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
commit 9c0ee995a8
parent 09ad3b76b3
@@ -7,6 +7,7 @@ import torch
 
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.platforms import current_platform
+from vllm.utils.math_utils import next_power_of_2
 
 NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [128, 256]
@@ -22,6 +23,10 @@ QDTYPES = (
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
 
+# 0: use 2D kernel for decode
+# 8: use 3D kernel for decode
+SEQ_THRESHOLD_3D_VALUES = [0, 8]
+
 
 def ref_paged_attn(
     query: torch.Tensor,
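Note: the two parametrized threshold values deliberately exercise both dispatch paths: a threshold of 0 routes every decode batch to the 2D kernel, while 8 lets small decode batches take the 3D kernel. A toy sketch of that routing, assuming the decode-only predicate added to unified_attention further down (batch size hypothetical):

    # Decode-only batch: max_seqlen_q == 1, so routing depends only on batch size.
    for seq_threshold_3D in (0, 8):
        num_seqs = 4
        kernel = "2D" if num_seqs > seq_threshold_3D else "3D"
        print(f"threshold={seq_threshold_3D} -> {kernel} kernel")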
@@ -92,6 +97,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("soft_cap", [None, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("q_dtype", QDTYPES)
+@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
 @torch.inference_mode()
 def test_triton_unified_attn(
     seq_lens: list[tuple[int, int]],
@@ -103,6 +109,7 @@ def test_triton_unified_attn(
     soft_cap: float | None,
     num_blocks: int,
     q_dtype: torch.dtype | None,
+    seq_threshold_3D: int,
 ) -> None:
     torch.set_default_device("cuda")
 
@@ -152,6 +159,21 @@ def test_triton_unified_attn(
     k_descale = torch.rand(scale_shape, dtype=torch.float32)
     v_descale = torch.rand(scale_shape, dtype=torch.float32)
 
+    num_par_softmax_segments = 16
+    head_size_padded = next_power_of_2(head_size)
+    softmax_segm_output = torch.empty(
+        (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
+        dtype=torch.float32,
+    )
+    softmax_segm_max = torch.empty(
+        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
+        dtype=torch.float32,
+    )
+    softmax_segm_expsum = torch.empty(
+        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
+        dtype=torch.float32,
+    )
+
     unified_attention(
         q=maybe_quantized_query,
         k=maybe_quantized_key_cache,
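Note: the scratch tensors are sized by seq_threshold_3D, the largest decode batch the 3D path can see, rather than by the maximum batch size, which keeps the preallocated footprint small. A standalone shape check (sketch only; next_power_of_2 is re-declared locally as a stand-in for vllm.utils.math_utils, and the values match the test defaults):

    import torch

    def next_power_of_2(n: int) -> int:  # local stand-in, assumption
        return 1 << (n - 1).bit_length()

    seq_threshold_3D, num_query_heads, head_size = 8, 8, 128
    num_par_softmax_segments = 16
    head_size_padded = next_power_of_2(head_size)

    segm_output = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
        dtype=torch.float32,
    )
    # 8 * 8 * 16 * 128 * 4 bytes = 0.5 MiB of float32 scratch
    print(segm_output.numel() * segm_output.element_size(), "bytes")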
@@ -169,6 +191,11 @@ def test_triton_unified_attn(
         q_descale=q_descale,
         k_descale=k_descale,
         v_descale=v_descale,
+        seq_threshold_3D=seq_threshold_3D,
+        num_par_softmax_segments=num_par_softmax_segments,
+        softmax_segm_output=softmax_segm_output,
+        softmax_segm_max=softmax_segm_max,
+        softmax_segm_expsum=softmax_segm_expsum,
     )
 
     ref_output = ref_paged_attn(
@@ -355,7 +355,7 @@ def kernel_unified_attention_2d(
 @triton.jit
 def kernel_unified_attention_3d(
     segm_output_ptr,
-    # [num_tokens, num_query_heads, num_segments, head_size]
+    # [num_tokens, num_query_heads, num_segments, head_size_padded]
     segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
     segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
     query_ptr,  # [num_tokens, num_query_heads, head_size]
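Note: the last dimension of segm_output is padded because Triton block shapes are required to be powers of two, so a head size such as 96 is stored in a 128-wide buffer. A quick illustration using Triton's own helper (head sizes hypothetical):

    import triton

    for head_size in (64, 96, 128, 256):
        print(head_size, "->", triton.next_power_of_2(head_size))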
@@ -749,6 +749,11 @@ def unified_attention(
     q_descale,
     k_descale,
     v_descale,
+    seq_threshold_3D=None,
+    num_par_softmax_segments=None,
+    softmax_segm_output=None,
+    softmax_segm_max=None,
+    softmax_segm_expsum=None,
     alibi_slopes=None,
     output_scale=None,
     qq_bias=None,
@@ -793,8 +798,19 @@ def unified_attention(
     TILE_SIZE_PREFILL = 32
     TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
 
-    # if batch contains a prefill
-    if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
+    # Launch the 2D kernel if
+    # 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
+    # 2. The batch includes at least one prefill request, or
+    # 3. The number of sequences exceeds the configured threshold
+    if (
+        seq_threshold_3D is None
+        or num_par_softmax_segments is None
+        or softmax_segm_output is None
+        or softmax_segm_max is None
+        or softmax_segm_expsum is None
+        or max_seqlen_q > 1
+        or num_seqs > seq_threshold_3D
+    ):
         kernel_unified_attention_2d[
             (
                 total_num_q_blocks,
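Note: read as a standalone predicate, the new dispatch rule is the following (sketch; parameter names mirror the kernel arguments):

    def use_2d_kernel(
        seq_threshold_3D,
        num_par_softmax_segments,
        softmax_segm_output,
        softmax_segm_max,
        softmax_segm_expsum,
        max_seqlen_q: int,
        num_seqs: int,
    ) -> bool:
        # 2D kernel when the preallocated 3D scratch buffers are missing,
        # when any request is a prefill (query length > 1), or when the batch
        # is already large enough to fill the 2D launch grid on its own.
        buffers_missing = (
            seq_threshold_3D is None
            or num_par_softmax_segments is None
            or softmax_segm_output is None
            or softmax_segm_max is None
            or softmax_segm_expsum is None
        )
        return buffers_missing or max_seqlen_q > 1 or num_seqs > seq_threshold_3D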
@@ -847,37 +863,12 @@ def unified_attention(
             USE_FP8=output_scale is not None,
         )
     else:
-        # for initial version, NUM_SEGMENTS = 16 is chosen as a default
-        # value that showed good performance in tests
-        NUM_SEGMENTS = 16
-
-        segm_output = torch.empty(
-            q.shape[0],
-            num_query_heads,
-            NUM_SEGMENTS,
-            triton.next_power_of_2(head_size),
-            dtype=torch.float32,
-            device=q.device,
-        )
-        segm_max = torch.empty(
-            q.shape[0],
-            num_query_heads,
-            NUM_SEGMENTS,
-            dtype=torch.float32,
-            device=q.device,
-        )
-        segm_expsum = torch.empty(
-            q.shape[0],
-            num_query_heads,
-            NUM_SEGMENTS,
-            dtype=torch.float32,
-            device=q.device,
-        )
-
-        kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
-            segm_output_ptr=segm_output,
-            segm_max_ptr=segm_max,
-            segm_expsum_ptr=segm_expsum,
+        kernel_unified_attention_3d[
+            (total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
+        ](
+            segm_output_ptr=softmax_segm_output,
+            segm_max_ptr=softmax_segm_max,
+            segm_expsum_ptr=softmax_segm_expsum,
             query_ptr=q,
             key_cache_ptr=k,
             value_cache_ptr=v,
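Note: hoisting the torch.empty calls out of the hot path is what makes the 3D branch CUDA-graph-safe: a captured graph replays fixed device addresses, so every tensor the kernels touch must already exist at capture time. A minimal illustration of the general pattern (generic PyTorch sketch, not vLLM's capture code; warmup and stream handling omitted):

    import torch

    # Preallocated, address-stable buffers (analogous to softmax_segm_*).
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # Only in-place ops on preallocated tensors are recorded here.
        static_out.copy_(static_in)
        static_out.mul_(2.0)

    static_in.fill_(3.0)
    g.replay()         # re-runs the captured work on the same addresses
    print(static_out)  # tensor of 6s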
@@ -917,13 +908,13 @@ def unified_attention(
             BLOCK_Q=BLOCK_Q,
             num_seqs=num_seqs,
             BLOCK_M=BLOCK_M,
-            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
         )
         reduce_segments[(q.shape[0], num_query_heads)](
             output_ptr=out,
-            segm_output_ptr=segm_output,
-            segm_max_ptr=segm_max,
-            segm_expsum_ptr=segm_expsum,
+            segm_output_ptr=softmax_segm_output,
+            segm_max_ptr=softmax_segm_max,
+            segm_expsum_ptr=softmax_segm_expsum,
             seq_lens_ptr=seqused_k,
             num_seqs=num_seqs,
             num_query_heads=num_query_heads,
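Note: reduce_segments combines the per-segment partials with the standard numerically stable softmax merge: each segment contributes an unnormalized output plus its running max and exp-sum, and the reduction rescales everything by the global max. In tensor form (sketch of the math, not the Triton kernel; shapes are for one query token and one head):

    import torch

    def merge_segments(segm_output, segm_max, segm_expsum):
        # segm_output: [num_segments, head_size] unnormalized partial outputs
        # segm_max:    [num_segments] per-segment max of the attention logits
        # segm_expsum: [num_segments] per-segment sum of exp(logit - segm_max)
        m = segm_max.max()                   # global max for stability
        alpha = torch.exp(segm_max - m)      # per-segment rescale factors
        denom = (alpha * segm_expsum).sum()  # global softmax denominator
        return (alpha[:, None] * segm_output).sum(dim=0) / denom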
@@ -936,6 +927,6 @@ def unified_attention(
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             query_start_len_ptr=cu_seqlens_q,
             BLOCK_Q=BLOCK_Q,
-            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
             USE_FP8=output_scale is not None,
         )
@@ -17,7 +17,7 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
     triton_reshape_and_cache_flash,
 )
 from vllm.attention.ops.triton_unified_attention import unified_attention
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 )
 from vllm.platforms import current_platform
 from vllm.platforms.interface import DeviceCapability
+from vllm.utils.math_utils import next_power_of_2
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     AttentionMetadataBuilder,
@@ -36,6 +37,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 logger = init_logger(__name__)
 
 
+# constants
+MIN_LAUNCH_GRID_SIZE_2D = 128  # Minimum launch grid size of 2D kernel
+NUM_PAR_SOFTMAX_SEGMENTS = 16  # Number of parallel tiled softmax segments
+
+
 @dataclass
 class TritonAttentionMetadata:
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
@@ -54,6 +60,12 @@ class TritonAttentionMetadata:
     block_table: torch.Tensor
     slot_mapping: torch.Tensor
 
+    seq_threshold_3D: int
+    num_par_softmax_segments: int
+    softmax_segm_output: torch.Tensor
+    softmax_segm_max: torch.Tensor
+    softmax_segm_expsum: torch.Tensor
+
     # For cascade attention.
     use_cascade: bool
     common_prefix_len: int
@@ -87,6 +99,60 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
         self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
         self.headdim = model_config.get_head_size()
 
+        # Check if CUDA Graphs are enabled for decode
+        self.decode_cudagraph_enabled = (
+            self.vllm_config.compilation_config.cudagraph_mode
+            in (
+                CUDAGraphMode.FULL_AND_PIECEWISE,
+                CUDAGraphMode.FULL_DECODE_ONLY,
+                CUDAGraphMode.FULL,
+            )
+        )
+
+        # The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
+        # A lower bound for num_q_blocks is the number of sequences.
+        # To ensure the minimum launch grid size is achieved, the number of sequences
+        # must be at least equal to the threshold below.
+        # If this threshold is not reached (i.e., the batch size is not large enough),
+        # the 3D kernel will be selected instead.
+        self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
+
+        # Modify the threshold if needed.
+        if self.decode_cudagraph_enabled:
+            capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
+            assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified."
+
+            # Select the CUDA Graph capture size closest to self.seq_threshold_3D
+            # as threshold. This ensures that each captured graph covers the
+            # correct execution path.
+            self.seq_threshold_3D = min(
+                capture_sizes,
+                key=lambda x: abs(x - self.seq_threshold_3D),
+            )
+
+        self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
+        headdim_padded = next_power_of_2(self.headdim)
+        self.softmax_segm_output = torch.empty(
+            (
+                self.seq_threshold_3D,
+                self.num_heads_q,
+                self.num_par_softmax_segments,
+                headdim_padded,
+            ),
+            dtype=torch.float32,
+            device=device,
+        )
+        self.softmax_segm_max = torch.empty(
+            (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
+            dtype=torch.float32,
+            device=device,
+        )
+        self.softmax_segm_expsum = torch.empty(
+            (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
+            dtype=torch.float32,
+            device=device,
+        )
+
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> TritonAttentionMetadata:
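Note: for example, with num_heads_kv = 8 the raw threshold is 128 // 8 = 16; if full-graph capture is enabled and 16 is not among the capture sizes, the builder snaps the threshold to the nearest capture size so that no captured batch size straddles the 2D/3D decision boundary. Sketch with hypothetical capture sizes:

    MIN_LAUNCH_GRID_SIZE_2D = 128
    num_heads_kv = 8
    seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // num_heads_kv  # 16

    capture_sizes = [1, 2, 4, 8, 12, 24]  # hypothetical
    seq_threshold_3D = min(capture_sizes, key=lambda x: abs(x - seq_threshold_3D))
    print(seq_threshold_3D)  # 12, the closest capture size to 16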
@@ -143,6 +209,11 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            seq_threshold_3D=self.seq_threshold_3D,
+            num_par_softmax_segments=self.num_par_softmax_segments,
+            softmax_segm_output=self.softmax_segm_output,
+            softmax_segm_max=self.softmax_segm_max,
+            softmax_segm_expsum=self.softmax_segm_expsum,
         )
         return attn_metadata
@@ -349,6 +420,12 @@ class TritonAttentionImpl(AttentionImpl):
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
 
+        seq_threshold_3D = attn_metadata.seq_threshold_3D
+        num_par_softmax_segments = attn_metadata.num_par_softmax_segments
+        softmax_segm_output = attn_metadata.softmax_segm_output
+        softmax_segm_max = attn_metadata.softmax_segm_max
+        softmax_segm_expsum = attn_metadata.softmax_segm_expsum
+
         descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
 
         unified_attention(
@@ -369,6 +446,11 @@ class TritonAttentionImpl(AttentionImpl):
             q_descale=None,  # Not supported
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
+            seq_threshold_3D=seq_threshold_3D,
+            num_par_softmax_segments=num_par_softmax_segments,
+            softmax_segm_output=softmax_segm_output,
+            softmax_segm_max=softmax_segm_max,
+            softmax_segm_expsum=softmax_segm_expsum,
             sinks=self.sinks,
             output_scale=output_scale,
         )