[V1][CUDA] Full cudagraph support for FlashInfer (#21367)
parent 3654847db5
commit 23322431c8
@@ -25,7 +25,8 @@ if is_flash_attn_varlen_func_available():
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -153,7 +154,9 @@ def _get_sliding_window_configs(

 class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
+        else AttentionCGSupport.ALWAYS

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -4,26 +4,28 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, ClassVar, Optional, Union

 import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
-from flashinfer.decode import trtllm_batch_decode_with_kv_cache
+from flashinfer.decode import (_get_range_buf, get_seq_lens,
+                               trtllm_batch_decode_with_kv_cache)

 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv
+from vllm.utils import cdiv, is_pin_memory_available
 from vllm.utils.flashinfer import use_trtllm_decode_attention
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
-    get_per_layer_parameters, infer_global_hyperparameters,
-    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
+    get_kv_cache_layout, get_per_layer_parameters,
+    infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
+    split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec

 if TYPE_CHECKING:
@@ -174,26 +176,66 @@ class FlashInferMetadata:


 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         self.device = device
+        self.vllm_config = vllm_config
+        self.cache_config = vllm_config.cache_config
+        self.kv_cache_spec = kv_cache_spec
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
-        self._decode_wrapper = None  # Wrapper for decode
+        self._decode_wrapper = None  # Wrapper for decode (general shape)
+
+        self.compilation_config = vllm_config.compilation_config
+        max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
+                                     self.kv_cache_spec.block_size)
+        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
+        max_num_pages = max_num_reqs * max_num_pages_per_req
+        self.enable_cuda_graph = self.compilation_config.full_cuda_graph
+        if self.enable_cuda_graph:
+            # For full cudagraph capture, one `decode_wrapper` for each batch
+            # size is needed for FlashInfer.
+            self._decode_wrappers_cudagraph: dict[
+                int, BatchDecodeWithPagedKVCacheWrapper] = {}
+            self._decode_cudagraph_max_bs = min(
+                max_num_reqs, self.compilation_config.max_capture_size)
+
         self._cascade_wrapper = None  # Wrapper for cascade attention

         # Global hyperparameters shared by all attention layers
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))

-        self.vllm_config = vllm_config
-        self.cache_config = vllm_config.cache_config
-        self.kv_cache_spec = kv_cache_spec
-        max_num_blocks_per_request = cdiv(
-            vllm_config.model_config.max_model_len,
-            self.kv_cache_spec.block_size)
-        self.block_table_arange = torch.arange(max_num_blocks_per_request,
+        # Preparing persistent buffers (device-side)
+        self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.paged_kv_indices = torch.zeros(
+            max_num_pages,  # max num pages possible
+            dtype=torch.int32,
+            device=self.device)
+        self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
+                                                  dtype=torch.int32,
+                                                  device=self.device)
+        # host-side buffer
+        pin_memory = is_pin_memory_available()
+        self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device="cpu",
+                                               pin_memory=pin_memory)
+        self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
+                                                dtype=torch.int32,
+                                                device="cpu",
+                                                pin_memory=pin_memory)
+        self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs,
+                                                      dtype=torch.int32,
+                                                      device="cpu",
+                                                      pin_memory=pin_memory)
+
+        self.block_table_arange = torch.arange(max_num_pages_per_req,
                                                dtype=torch.int32,
                                                device=self.device)

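For reference, the persistent-buffer sizing in the constructor above is plain arithmetic: one indptr slot per request plus a sentinel, one indices slot per KV-cache page, and one last-page-length slot per request. A standalone sketch with illustrative numbers (the concrete values are assumptions, not taken from this commit):

# Sketch of the buffer sizing used in __init__ above; all numbers are assumed.
def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring vllm.utils.cdiv.
    return -(a // -b)

max_model_len = 8192   # hypothetical model context length
block_size = 16        # hypothetical KV-cache page size
max_num_reqs = 256     # hypothetical scheduler_config.max_num_seqs

max_num_pages_per_req = cdiv(max_model_len, block_size)  # 512
max_num_pages = max_num_reqs * max_num_pages_per_req     # 131072

# paged_kv_indptr:        max_num_reqs + 1 entries
# paged_kv_indices:       max_num_pages entries
# paged_kv_last_page_len: max_num_reqs entries
print(max_num_reqs + 1, max_num_pages, max_num_reqs)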
@@ -217,8 +259,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper

-    def _get_decode_wrapper(self):
-        if self._decode_wrapper is None:
+    def _get_decode_wrapper(self,
+                            batch_size: int,
+                            use_cudagraph: bool = False):
+        if use_cudagraph:
+            decode_wrapper = self._decode_wrappers_cudagraph.get(
+                batch_size, None)
+        else:
+            decode_wrapper = self._decode_wrapper
+
+        if decode_wrapper is None:
             num_qo_heads = (
                 self.vllm_config.model_config.get_num_attention_heads(
                     self.vllm_config.parallel_config))
@@ -226,11 +276,32 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 self.vllm_config.parallel_config)
             use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                 num_qo_heads // num_kv_heads > 4)
-            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+
+            if use_cudagraph:
+                paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
+                paged_kv_indices = self.paged_kv_indices
+                paged_kv_last_page_len = self.paged_kv_last_page_len[:
+                                                                     batch_size]
+            else:
+                paged_kv_indptr = None
+                paged_kv_indices = None
+                paged_kv_last_page_len = None
+            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
                 get_kv_cache_layout(),
+                use_cuda_graph=use_cudagraph,
+                paged_kv_indptr_buffer=paged_kv_indptr,
+                paged_kv_indices_buffer=paged_kv_indices,
+                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                 use_tensor_cores=use_tensor_cores)
-        return self._decode_wrapper
+
+            # save the decode wrapper
+            if use_cudagraph:
+                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
+            else:
+                self._decode_wrapper = decode_wrapper
+
+        return decode_wrapper

     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
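The wrapper handling above is a cache-per-shape pattern: cudagraph replay needs fixed shapes, so each padded decode batch size keeps its own wrapper bound to slices of the persistent buffers, while eager execution reuses a single general wrapper. A GPU-free sketch of just that caching logic (DummyWrapper stands in for BatchDecodeWithPagedKVCacheWrapper; it is an illustration, not a FlashInfer API):

from typing import Optional


class DummyWrapper:
    """Stand-in for BatchDecodeWithPagedKVCacheWrapper in this sketch."""

    def __init__(self, batch_size: Optional[int], use_cuda_graph: bool):
        self.batch_size = batch_size
        self.use_cuda_graph = use_cuda_graph


class DecodeWrapperCache:
    def __init__(self):
        self._general: Optional[DummyWrapper] = None
        self._per_batch_size: dict[int, DummyWrapper] = {}

    def get(self, batch_size: int, use_cudagraph: bool) -> DummyWrapper:
        # Cudagraph replay requires fixed shapes, so each padded batch size
        # gets (and keeps) its own wrapper; eager mode reuses one wrapper.
        if use_cudagraph:
            wrapper = self._per_batch_size.get(batch_size)
        else:
            wrapper = self._general
        if wrapper is None:
            wrapper = DummyWrapper(batch_size if use_cudagraph else None,
                                   use_cudagraph)
            if use_cudagraph:
                self._per_batch_size[batch_size] = wrapper
            else:
                self._general = wrapper
        return wrapper


cache = DecodeWrapperCache()
assert cache.get(8, True) is cache.get(8, True)       # reused per size
assert cache.get(8, True) is not cache.get(16, True)  # separate per size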
@@ -308,16 +379,44 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             )

         if num_decodes > 0:
-            attn_metadata.decode_wrapper = self._get_decode_wrapper()
+            pure_decode = num_prefills == 0
+            # possible required padding for cudagraph replay
+            use_cudagraph = (self.enable_cuda_graph and pure_decode and
+                             num_decodes <= self._decode_cudagraph_max_bs)
+            if use_cudagraph:
+                num_input_tokens = (
+                    self.vllm_config.pad_for_cudagraph(num_decodes))
+                # Carefully fulfill the padding region with reasonable value
+                # on cpu.
+                # Make sure paged_kv_indptr_cpu is not decreasing
+                self.paged_kv_indptr_cpu[1 + num_decodes:1 +
+                                         num_input_tokens].fill_(
+                                             attn_metadata.
+                                             paged_kv_indptr_cpu[-1])
+                # Fill the remaining paged_kv_last_page_len_cpu with 1.
+                # This is because flashinfer treats 0 as a full page
+                # instead of empty.
+                self.paged_kv_last_page_len_cpu[
+                    num_decodes:num_input_tokens].fill_(1)
+
+            else:
+                num_input_tokens = num_decodes
+
+            attn_metadata.decode_wrapper = self._get_decode_wrapper(
+                num_input_tokens, use_cudagraph)
             if not use_trtllm_decode_attention(
                     num_decodes, attn_metadata.max_seq_len,
                     self.cache_config.cache_dtype,
                     attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
                     attn_metadata.head_dim):
-                attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
+                # Use the persistent buffer with padding length,
+                # instead of the same address but chunked version
+                # in atten_metadata when using cudagraph.
+                fast_plan_decode(
+                    attn_metadata.decode_wrapper,
+                    self.paged_kv_indptr_cpu[:num_input_tokens + 1],
                     attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
+                    self.paged_kv_last_page_len_cpu[:num_input_tokens],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
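When a pure-decode batch is padded up to a cudagraph capture size, the code above fills the padding region on the CPU side so the replayed plan stays valid: the indptr tail repeats its last real value (keeping it non-decreasing, i.e. padded requests own zero pages), and the padded last-page lengths are set to 1 because FlashInfer treats 0 as a full page rather than an empty one. A small CPU-only illustration with made-up sizes:

import torch

num_decodes = 3        # real decode requests (assumed)
num_input_tokens = 8   # padded cudagraph batch size (assumed)
max_num_reqs = 8

paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1, dtype=torch.int32)
paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs, dtype=torch.int32)

# Pretend the real requests own 4, 2 and 5 pages respectively.
paged_kv_indptr_cpu[1:num_decodes + 1] = torch.tensor([4, 6, 11],
                                                      dtype=torch.int32)
paged_kv_last_page_len_cpu[:num_decodes] = torch.tensor([3, 16, 7],
                                                        dtype=torch.int32)

# Padding region: keep indptr monotonic by repeating the last real value
# (index num_decodes, i.e. the "[-1]" of the real slice), and use 1 rather
# than 0 for the padded last-page lengths.
paged_kv_indptr_cpu[1 + num_decodes:1 + num_input_tokens].fill_(
    paged_kv_indptr_cpu[num_decodes])
paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(1)

print(paged_kv_indptr_cpu.tolist())         # [0, 4, 6, 11, 11, 11, 11, 11, 11]
print(paged_kv_last_page_len_cpu.tolist())  # [3, 16, 7, 1, 1, 1, 1, 1]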
@@ -336,6 +435,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> FlashInferMetadata:
+        num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
             split_decodes_and_prefills(common_attn_metadata)
@@ -381,18 +481,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                              non_blocking=True)
         mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
+        # write self.paged_kv_indices inplace
+        num_actual_pages = torch.sum(mask)
+        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
+        torch.masked_select(block_table_tensor[:, :max_num_blocks],
+                            mask,
+                            out=paged_kv_indices)

-        paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
-                                          dtype=torch.int32,
-                                          device='cpu')
-        paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
-            dim=0, dtype=torch.int32)
+        # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
+        torch.cumsum(block_table_bounds_cpu,
+                     dim=0,
+                     dtype=torch.int32,
+                     out=self.paged_kv_indptr_cpu[1:1 + num_reqs])

         paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
-        paged_kv_last_page_len_cpu = torch.where(
-            paged_kv_last_page_len_cpu == 0, page_size,
-            paged_kv_last_page_len_cpu)
+        # write self.paged_kv_last_page_len_cpu inplace
+        torch.where(paged_kv_last_page_len_cpu == 0,
+                    torch.tensor(page_size),
+                    paged_kv_last_page_len_cpu,
+                    out=self.paged_kv_last_page_len_cpu[:num_reqs])

         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
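The build-time rewrite above trades per-step allocations for writes into the preallocated buffers via the out= variants of torch.masked_select and torch.cumsum. A small CPU-only illustration of the same idiom (the block table and sizes are made up):

import torch

# Persistent buffers, sized for an assumed worst case.
paged_kv_indices = torch.zeros(64, dtype=torch.int32)
paged_kv_indptr_cpu = torch.zeros(4 + 1, dtype=torch.int32)

# A tiny block table: 3 requests, up to 4 pages each; -1 marks unused slots.
block_table = torch.tensor([[10, 11, 12, -1],
                            [20, -1, -1, -1],
                            [30, 31, -1, -1]], dtype=torch.int32)
block_table_bounds = torch.tensor([3, 1, 2], dtype=torch.int32)

mask = (torch.arange(block_table.shape[1]).unsqueeze(0)
        < block_table_bounds.unsqueeze(1))
num_actual_pages = int(mask.sum())

# Write the valid page ids into a view of the persistent buffer.
torch.masked_select(block_table, mask,
                    out=paged_kv_indices[:num_actual_pages])
# Cumulative page counts land in indptr[1:]; index 0 stays 0.
torch.cumsum(block_table_bounds, dim=0, dtype=torch.int32,
             out=paged_kv_indptr_cpu[1:1 + 3])

print(paged_kv_indices[:num_actual_pages].tolist())  # [10, 11, 12, 20, 30, 31]
print(paged_kv_indptr_cpu.tolist())                  # [0, 3, 4, 6]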
@@ -402,9 +510,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
-            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
+            paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs],
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
+            paged_kv_last_page_len_cpu=self.
+            paged_kv_last_page_len_cpu[:num_reqs],
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -431,6 +540,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

         return attn_metadata

+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with FlashInfer.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "FlashInfer only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
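build_for_cudagraph_capture above leans on a simple invariant: in a decode-only batch each request schedules exactly one token, so num_reqs equals num_actual_tokens and max_query_len is 1 (which is also all that can_run_in_cudagraph checks). A tiny sketch of that invariant (names are local to the sketch):

def is_pure_decode(num_reqs: int, num_actual_tokens: int,
                   max_query_len: int) -> bool:
    # Decode-only batches schedule exactly one token per request, which is
    # why a max_query_len == 1 check is sufficient at replay time.
    return num_reqs == num_actual_tokens and max_query_len == 1


assert is_pure_decode(num_reqs=4, num_actual_tokens=4, max_query_len=1)
assert not is_pure_decode(num_reqs=4, num_actual_tokens=9, max_query_len=6)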
@@ -638,3 +767,163 @@ class FlashInferImpl(AttentionImpl):
                 out=output[:num_decode_tokens],
             )
         return output_padded
+
+
+def fast_plan_decode(
+        self,  # decode wrapper
+        indptr_cpu: torch.Tensor,
+        indices: torch.Tensor,
+        last_page_len_cpu: torch.Tensor,
+        num_qo_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        page_size: int,
+        pos_encoding_mode: str = "NONE",
+        window_left: int = -1,
+        logits_soft_cap: Optional[float] = None,
+        q_data_type: Optional[Union[str, torch.dtype]] = "float16",
+        kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        data_type: Optional[Union[str, torch.dtype]] = None,
+        sm_scale: Optional[float] = None,
+        rope_scale: Optional[float] = None,
+        rope_theta: Optional[float] = None,
+        non_blocking: bool = True,
+) -> None:
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
+    cudagraph capture/replay, while the no cudagraph version turns back
+    to the original plan.
+    using original plan after passing host-side buffers:
+    - only host-to-device copy of indptr and last_page_len buffers
+    Modifications for cudagraph:
+    - only host-to-device copy of indptr and last_page_len buffers.
+    - avoid device-to-device copy of indices buffer.
+
+    Part of the code get inspiration from the original plan from FlashInfer repo
+    and the implementation of fast_decode_plan for FlashInfer in SGlang repo.
+    """
+    # Warm up with the original plan if it is first call, and always run the
+    # original plan if we run for dynamic shape. For fixed shape (cudagraph),
+    # this warm up is to generate the _cached_module for the decode wrapper.
+    if not self.is_cuda_graph_enabled or \
+            getattr(self, "vllm_first_call", True):
+        self.plan(
+            indptr_cpu,
+            indices,
+            last_page_len_cpu,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            page_size,
+            pos_encoding_mode,
+            window_left,
+            logits_soft_cap,
+            q_data_type,
+            kv_data_type,
+            data_type,
+            sm_scale,
+            rope_scale,
+            rope_theta,
+            non_blocking,
+        )
+        self.vllm_first_call = False
+        return
+
+    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"
+
+    batch_size = len(last_page_len_cpu)
+    if logits_soft_cap is None:
+        logits_soft_cap = 0.0
+
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+    q_data_type = getattr(torch, q_data_type) if isinstance(
+        q_data_type, str) else q_data_type
+    kv_data_type = getattr(torch, kv_data_type) if isinstance(
+        kv_data_type, str) else kv_data_type
+
+    if self.use_tensor_cores:
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+
+    if batch_size != self._fixed_batch_size:
+        raise ValueError(
+            "The batch size should be fixed in cudagraph mode, the runtime "
+            "batch size {} mismatches the batch size set during "
+            "initialization {}".format(batch_size, self._fixed_batch_size))
+    if len(indices) > len(self._paged_kv_indices_buf):
+        raise ValueError(
+            "The size of indices should be less than or equal to the "
+            "allocated buffer")
+
+    # host-to-device copy for the indptr buffer
+    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
+    # host-to-device copy for the last_page_len buffer
+    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
+                                           non_blocking=True)
+
+    indptr_host = indptr_cpu
+    last_page_len_host = last_page_len_cpu
+
+    if self.use_tensor_cores:
+        kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host,
+                                        page_size)
+
+        try:
+            # Make sure we pass exactly 15 arguments for tensor core version
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                indptr_host,
+                kv_lens_arr_host,
+                batch_size,  # total_num_rows
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim,
+                head_dim,
+                False,  # causal
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in tensor core plan: {e}") from e
+    else:
+        try:
+            # Make sure we pass exactly 15 arguments for standard version
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                indptr_host,
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                window_left,
+                logits_soft_cap,
+                head_dim,
+                head_dim,
+                torch.empty(0, dtype=q_data_type),
+                torch.empty(0, dtype=kv_data_type),
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in standard plan: {e}") from e
+
+    self._pos_encoding_mode = pos_encoding_mode
+    self._window_left = window_left
+    self._logits_soft_cap = logits_soft_cap
+    self._sm_scale = sm_scale
+    self._rope_scale = rope_scale
+    self._rope_theta = rope_theta
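fast_plan_decode above is a "warm up once, then fast path" routine: the first call (or any non-cudagraph call) goes through the wrapper's original plan(), which also creates the cached planning module, while subsequent fixed-shape calls only refresh the small host-side buffers before re-planning. A library-free sketch of that control flow (everything here is illustrative, not FlashInfer API):

class PlannerSketch:
    """Illustrative stand-in for a decode wrapper with a cached planner."""

    def __init__(self, use_cuda_graph: bool):
        self.is_cuda_graph_enabled = use_cuda_graph
        self._first_call = True
        self.calls: list[str] = []

    def slow_plan(self) -> None:
        # Full planning path; in FlashInfer this also builds _cached_module.
        self.calls.append("slow")

    def fast_plan(self) -> None:
        # Fixed-shape path: only refresh small host-side buffers.
        if not self.is_cuda_graph_enabled or self._first_call:
            self.slow_plan()
            self._first_call = False
            return
        self.calls.append("fast")


p = PlannerSketch(use_cuda_graph=True)
p.fast_plan()  # warm-up goes through the slow path
p.fast_plan()  # subsequent replays take the fast path
assert p.calls == ["slow", "fast"]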
@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                     MLACommonImpl,
                                                     MLACommonMetadata,
                                                     MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec

 logger = init_logger(__name__)
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):


 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -17,6 +17,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                     MLACommonImpl,
                                                     MLACommonMetadata,
                                                     MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec

 # yapf: enable
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):


 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True  # decode only
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -18,7 +18,8 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -57,7 +58,8 @@ class TritonAttentionMetadata:

 class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.ALWAYS

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, make_dataclass
@@ -65,9 +66,24 @@ class CommonAttentionMetadata:
 M = TypeVar("M")


+class AttentionCGSupport(enum.Enum):
+    """ Constants for the cudagraph support of the attention backend
+    Here we do not consider the cascade attention, as currently
+    it is never cudagraph supported."""
+
+    NEVER = 0
+    """NO cudagraph support"""
+    PURE_DECODE_ONLY = 1
+    """Cudagraph supported for pure decode, need to run without
+    cudagraph for mixed prefill-decode batches"""
+    ALWAYS = 2
+    """Cudagraph always supported"""
+
+
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
-    full_cudagraph_supported: ClassVar[bool] = False
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER

     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
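The new tri-state replaces the old full_cudagraph_supported boolean so callers can distinguish "never", "decode-only", and "always". A short sketch of how a caller is meant to branch on it (the enum values mirror the hunk above; the helper function is a simplified illustration of what GPUModelRunner does, not its actual code):

import enum


class AttentionCGSupport(enum.Enum):
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


def plan_full_cudagraph(support: AttentionCGSupport,
                        capture_sizes: list[int],
                        max_num_seqs: int) -> list[int]:
    if support is AttentionCGSupport.NEVER:
        raise ValueError("Full CUDAGraph not supported for this backend")
    if support is AttentionCGSupport.PURE_DECODE_ONLY:
        # Decode-only capture: each request contributes a single token, so
        # capture sizes above max_num_seqs can never be replayed.
        return [s for s in capture_sizes if s <= max_num_seqs]
    return capture_sizes  # ALWAYS


print(plan_full_cudagraph(AttentionCGSupport.PURE_DECODE_ONLY,
                          [1, 2, 4, 8, 16, 32, 64], max_num_seqs=16))
# -> [1, 2, 4, 8, 16]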
@@ -47,7 +47,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         is_pin_memory_available, round_up, supports_dynamo)
 from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -2619,12 +2619,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.device,
         )

-        if (self.full_cuda_graph
-                and not attn_metadata_builder_i.full_cudagraph_supported):
-            raise ValueError(
-                f"Full CUDAGraph not supported for "
-                f"{attn_backend_i.__name__}. Turn off CompilationConfig."
-                f"full_cuda_graph or use a different attention backend.")
+        if self.full_cuda_graph:
+            if attn_metadata_builder_i.attn_cudagraph_support == \
+                    AttentionCGSupport.NEVER:
+                raise ValueError(f"Full CUDAGraph not supported for "
+                                 f"{attn_backend_i.__name__}. Turn off "
+                                 f"CompilationConfig.full_cuda_graph or use a "
+                                 f" different attention backend.")
+            if attn_metadata_builder_i.attn_cudagraph_support == \
+                    AttentionCGSupport.PURE_DECODE_ONLY:
+                # Limit the max cudagraph size to the max number of
+                # sequences for pure decode only cudagraph backend,
+                # whose max_query_len is 1.
+                self.cudagraph_batch_sizes = [
+                    size for size in self.cudagraph_batch_sizes
+                    if size <= self.scheduler_config.max_num_seqs
+                ]
         return attn_backend_i, attn_metadata_builder_i

     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
@@ -321,11 +321,16 @@ class Worker(WorkerBase):
         if get_pp_group().is_last_rank:
             max_num_reqs = min(self.scheduler_config.max_num_seqs,
                                self.scheduler_config.max_num_batched_tokens)
+            # activate building attn_metadata for this dummy run to avoid
+            # potential illegal memory access for full cudagraph relay.
+            attn_cudagraph = self.compilation_config.full_cuda_graph and\
+                not self.model_config.enforce_eager

             # We skip EPLB here since we don't want to record dummy metrics
             hidden_states, last_hidden_states = \
                 self.model_runner._dummy_run(
                     num_tokens=max_num_reqs,
+                    capture_attn_cudagraph=attn_cudagraph,
                     skip_eplb=True,
                 )
             if self.model_runner.is_pooling_model: