[Attention] MLA - Flashinfer Ragged Prefill (#20034)

This commit is contained in:
Alexander Matveev 2025-07-10 23:17:47 -04:00 committed by GitHub
parent 922f316441
commit 5b032352cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 421 additions and 214 deletions

View File

View File

@ -3,16 +3,10 @@
import filecmp import filecmp
import shutil import shutil
import tempfile import tempfile
from collections import defaultdict
from pathlib import Path from pathlib import Path
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig, VllmConfig from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@ -25,65 +19,6 @@ PROMPTS = [
SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)
class TestSharedStorageConnector(SharedStorageConnector):
def __init__(self, config: VllmConfig, role):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = tempfile.gettempdir(
) + f"/connector_{self.name}-{self.role.name}_events.log"
# Start with an empty file
with open(self._event_file, "w") as _:
pass
def __getattribute__(self, name):
if name in ("_connector", "call_record", "name", "_event_file",
"__class__", "__dict__", "__getattribute__",
"__init__"): # avoid recursion
return object.__getattribute__(self, name)
if not hasattr(self._connector, name):
return object.__getattribute__(self, name)
attr = getattr(self._connector, name)
# Intercept calls to the connector interface and write an event
# for each one to a file, which can be read back in the main test proc.
if callable(attr):
def wrapper(*args, **kwargs):
self.call_record[name] += 1
# Include args that we're interested in
to_log = [name]
for arg in args:
if isinstance(arg, int):
to_log.append(str(arg))
elif isinstance(arg, KVCacheBlocks):
to_log.append(
f"num_blocks={[len(b) for b in arg.blocks]}")
# Log the event as a line to the file
try:
with open(self._event_file, "a") as f:
f.write(' '.join(to_log) + "\n")
except Exception as e:
print(f"[ERROR] Could not log event {name} "
f"for {self.name}: {e}")
return attr(*args, **kwargs)
return wrapper
return attr
# This relies on "fork" multiprocessing method being used.
# It's the default but vLLM may fall back to spawn if for example CUDA
# is already initialized.
KVConnectorFactory.register_connector("TestSharedStorageConnector",
TestSharedStorageConnector.__module__,
TestSharedStorageConnector.__name__)
# Helper function to compare directories recursively # Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool: def _compare_directories(dir1: Path, dir2: Path) -> bool:
"""Compares two directories recursively for identical content.""" """Compares two directories recursively for identical content."""
@ -118,19 +53,27 @@ def test_multi_shared_storage_connector_consistency():
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={ kv_connector_extra_config={
"connectors": [{ "connectors": [{
"kv_connector": "TestSharedStorageConnector", "kv_connector":
"kv_role": "kv_both", "TestSharedStorageConnector",
"kv_role":
"kv_both",
"kv_connector_extra_config": { "kv_connector_extra_config": {
"shared_storage_path": str(storage_1_path), "shared_storage_path": str(storage_1_path),
"name": "storage1", "name": "storage1",
} },
"kv_connector_module_path":
"tests.v1.kv_connector.unit.utils",
}, { }, {
"kv_connector": "TestSharedStorageConnector", "kv_connector":
"kv_role": "kv_both", "TestSharedStorageConnector",
"kv_role":
"kv_both",
"kv_connector_extra_config": { "kv_connector_extra_config": {
"shared_storage_path": str(storage_2_path), "shared_storage_path": str(storage_2_path),
"name": "storage2", "name": "storage2",
} },
"kv_connector_module_path":
"tests.v1.kv_connector.unit.utils",
}] }]
}, },
) )

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections import defaultdict
from typing import Any, Optional from typing import Any, Optional
import torch import torch
@ -7,6 +9,11 @@ import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig) ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec) KVCacheGroupSpec)
@ -187,3 +194,58 @@ def create_model_runner_output(
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
) )
class TestSharedStorageConnector(SharedStorageConnector):
def __init__(self, config: VllmConfig, role):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = tempfile.gettempdir(
) + f"/connector_{self.name}-{self.role.name}_events.log"
# Start with an empty file
with open(self._event_file, "w") as _:
pass
def __getattribute__(self, name):
if name in ("_connector", "call_record", "name", "_event_file",
"__class__", "__dict__", "__getattribute__",
"__init__"): # avoid recursion
return object.__getattribute__(self, name)
if not hasattr(self._connector, name):
return object.__getattribute__(self, name)
attr = getattr(self._connector, name)
# Intercept calls to the connector interface and write an event
# for each one to a file, which can be read back in the main test proc.
if callable(attr):
def wrapper(*args, **kwargs):
self.call_record[name] += 1
# Include args that we're interested in
to_log = [name]
for arg in args:
if isinstance(arg, int):
to_log.append(str(arg))
elif isinstance(arg, KVCacheBlocks):
to_log.append(
f"num_blocks={[len(b) for b in arg.blocks]}")
# Log the event as a line to the file
try:
with open(self._event_file, "a") as f:
f.write(' '.join(to_log) + "\n")
except Exception as e:
print(f"[ERROR] Could not log event {name} "
f"for {self.name}: {e}")
return attr(*args, **kwargs)
return wrapper
return attr
KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__,
TestSharedStorageConnector.__name__)

View File

@ -10,6 +10,7 @@ import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionType from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group, has_kv_transfer_group,
@ -21,7 +22,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
class Attention(nn.Module): class Attention(nn.Module):

View File

@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def validate_kv_sharing_target(current_layer_name, target_layer_name,
static_forward_context):
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
f"is not valid: target layer {target_layer_name} ")
if current_layer_name == target_layer_name:
raise ValueError(error_msg +
"cannot be the same as the current layer.")
if target_layer_name not in static_forward_context:
from vllm.model_executor.models.utils import extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx = extract_layer_index(current_layer_name)
target_layer_idx = extract_layer_index(target_layer_name)
if current_layer_idx <= target_layer_idx:
raise ValueError(error_msg + "must come before the current layer.")
else:
raise ValueError(error_msg +
"is not a valid Attention layer in the model.")
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type = static_forward_context[
target_layer_name].attn_type
expected = static_forward_context[current_layer_name].attn_type
if target_layer_attn_type != expected:
raise ValueError(
error_msg +
f"must be the same type as the current layer ({expected}).")

View File

@ -53,6 +53,12 @@ DEFAULT_LOGGING_CONFIG = {
} }
@lru_cache
def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
# Set the stacklevel to 2 to print the original caller's line info
logger.debug(msg, *args, stacklevel=2)
@lru_cache @lru_cache
def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None: def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
# Set the stacklevel to 2 to print the original caller's line info # Set the stacklevel to 2 to print the original caller's line info
@ -74,6 +80,13 @@ class _VllmLogger(Logger):
`intel_extension_for_pytorch.utils._logger`. `intel_extension_for_pytorch.utils._logger`.
""" """
def debug_once(self, msg: str, *args: Hashable) -> None:
"""
As [`debug`][logging.Logger.debug], but subsequent calls with
the same message are silently dropped.
"""
_print_debug_once(self, msg, *args)
def info_once(self, msg: str, *args: Hashable) -> None: def info_once(self, msg: str, *args: Hashable) -> None:
""" """
As [`info`][logging.Logger.info], but subsequent calls with As [`info`][logging.Logger.info], but subsequent calls with
@ -132,6 +145,7 @@ def init_logger(name: str) -> _VllmLogger:
logger = logging.getLogger(name) logger = logging.getLogger(name)
methods_to_patch = { methods_to_patch = {
"debug_once": _print_debug_once,
"info_once": _print_info_once, "info_once": _print_info_once,
"warning_once": _print_warning_once, "warning_once": _print_warning_once,
} }

View File

@ -14,13 +14,14 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType) AttentionType)
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
get_kv_cache_layout) PerLayerParameters,
get_kv_cache_layout,
get_per_layer_parameters,
infer_global_hyperparameters)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
@ -93,70 +94,6 @@ class FlashInferBackend(AttentionBackend):
return stride_order return stride_order
@dataclass
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
"""
window_left: int
logits_soft_cap: Optional[float]
sm_scale: float
def get_per_layer_parameters(
vllm_config: VllmConfig) -> dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""
layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
impl = layer.impl
assert isinstance(impl, FlashInferImpl)
# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale
per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)
return per_layer_params
def infer_global_hyperparameters(
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert len(per_layer_params) > 0, "No attention layers found in the model."
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")
return global_params
@dataclass @dataclass
class FlashInferMetadata: class FlashInferMetadata:
@ -336,7 +273,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def _plan(self, attn_metadata: FlashInferMetadata): def _plan(self, attn_metadata: FlashInferMetadata):
if self.global_hyperparameters is None: if self.global_hyperparameters is None:
self.global_hyperparameters = infer_global_hyperparameters( self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config)) get_per_layer_parameters(self.vllm_config, FlashInferImpl))
if attn_metadata.use_cascade: if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan( attn_metadata.cascade_wrapper.plan(

View File

@ -189,8 +189,8 @@ return curr_o @ W_O
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
import torch import torch
@ -208,7 +208,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata,
get_per_layer_parameters,
infer_global_hyperparameters)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
@ -221,6 +223,12 @@ except ImportError:
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
is_vllm_fa = False is_vllm_fa = False
try:
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
flashinfer_available = True
except ImportError:
flashinfer_available = False
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
@ -290,6 +298,13 @@ class MLACommonPrefillMetadata:
chunked_context: Optional[ChunkedContextMetadata] = None chunked_context: Optional[ChunkedContextMetadata] = None
@dataclass
class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
prefill_main: Optional['BatchPrefillWithRaggedKVCacheWrapper'] = None
prefill_chunks: list['BatchPrefillWithRaggedKVCacheWrapper'] = field(
default_factory=list)
@dataclass @dataclass
class MLACommonDecodeMetadata: class MLACommonDecodeMetadata:
block_table: torch.Tensor block_table: torch.Tensor
@ -328,7 +343,8 @@ class MLACommonMetadata(Generic[D]):
head_dim: Optional[int] = None head_dim: Optional[int] = None
decode: Optional[D] = None decode: Optional[D] = None
prefill: Optional[MLACommonPrefillMetadata] = None prefill: Optional[Union[MLACommonPrefillMetadata,
FlashInferPrefillMetadata]] = None
def __post_init__(self): def __post_init__(self):
if self.head_dim is not None: if self.head_dim is not None:
@ -338,6 +354,20 @@ class MLACommonMetadata(Generic[D]):
M = TypeVar("M", bound=MLACommonMetadata) M = TypeVar("M", bound=MLACommonMetadata)
def use_flashinfer_prefill() -> bool:
if flashinfer_available:
# For blackwell default to flashinfer prefill if its available since
# its faster than FA2.
return current_platform.has_device_capability(100)
return False
# Currently 394MB, this can be tuned based on GEMM sizes used.
# Choosen to be the same as sglang:
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
""" """
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
@ -392,6 +422,101 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
) )
self.block_table = block_table self.block_table = block_table
self._use_fi_prefill = use_flashinfer_prefill()
self.prefill_metadata_cls = FlashInferPrefillMetadata \
if self._use_fi_prefill else MLACommonPrefillMetadata
if self._use_fi_prefill:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=runner.device)
self._fi_prefill_main: Optional[
BatchPrefillWithRaggedKVCacheWrapper] = None
self._fi_prefill_chunks: list[
BatchPrefillWithRaggedKVCacheWrapper] = []
self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(runner.vllm_config, MLACommonImpl))
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc
has_context = False
if prefill.chunked_context is not None:
chunked_context = prefill.chunked_context
has_context = True
if self._fi_prefill_main is None:
self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper(
self._workspace_buffer, "NHD", backend="cutlass")
if has_context:
num_chunks = chunked_context.cu_seq_lens.shape[0]
# Allocate more prefill chunk wrappers if needed
if len(self._fi_prefill_chunks) < num_chunks:
for _ in range(len(self._fi_prefill_chunks), num_chunks):
self._fi_prefill_chunks.append(
BatchPrefillWithRaggedKVCacheWrapper(
self._workspace_buffer, "NHD", backend="cutlass"))
assert num_chunks <= len(self._fi_prefill_chunks)
# In MLA, the non-latent num_qo_heads == num_kv_heads
num_qo_heads = self.runner.num_query_heads
num_kv_heads = num_qo_heads
# Sanity: Verify that num_kv_heads == 1 since it is latent space
assert self.kv_cache_spec.num_kv_heads == 1
# Get non-latent head_dim_qk and head_dim_vo
head_dim_qk = (self.mla_dims.qk_nope_head_dim +
self.mla_dims.qk_rope_head_dim)
head_dim_vo = self.mla_dims.v_head_dim
# For main run, qo_indptr == kv_indptr
kv_indptr = qo_indptr.clone()
# Prepare main prefill
self._fi_prefill_main.plan(
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
causal=True, # This is main run
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.runner.dtype,
kv_data_type=self.kv_cache_spec.dtype,
)
# Prepare context prefills
if has_context:
for i in range(num_chunks):
kv_indptr_chunk = chunked_context.cu_seq_lens[i]
self._fi_prefill_chunks[i].plan(
qo_indptr=qo_indptr,
kv_indptr=kv_indptr_chunk,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
causal=False, # This is context run
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.
logits_soft_cap,
q_data_type=self.runner.dtype,
kv_data_type=self.kv_cache_spec.dtype,
)
prefill.prefill_main = self._fi_prefill_main
prefill.prefill_chunks = self._fi_prefill_chunks
def reorder_batch(self, input_batch: "InputBatch", def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are and # We now want to reorder the batch so that the "decode" requests are and
@ -572,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
assert max(chunked_context_metadata.max_seq_lens) <= \ assert max(chunked_context_metadata.max_seq_lens) <= \
self.chunked_prefill_workspace_size self.chunked_prefill_workspace_size
prefill_metadata = MLACommonPrefillMetadata( prefill_metadata = self.prefill_metadata_cls(
block_table=block_table_tensor[reqs_start:, ...], block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc, query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len, max_query_len=max_query_len,
@ -586,7 +711,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens=seq_lens[:self._num_decodes], seq_lens=seq_lens[:self._num_decodes],
) )
return self.metadata_cls( attn_metadata = self.metadata_cls(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
@ -599,6 +724,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode=decode_metadata, decode=decode_metadata,
) )
if self._use_fi_prefill and self._num_prefills > 0:
assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
self._build_fi_prefill_wrappers(attn_metadata.prefill)
return attn_metadata
def can_run_in_cudagraph( def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool: self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1 return common_attn_metadata.max_query_len == 1
@ -649,23 +780,34 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.v_head_dim = v_head_dim self.v_head_dim = v_head_dim
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
# Handle the differences between the flash_attn_varlen from flash_attn if use_flashinfer_prefill():
# and the one from vllm_flash_attn. The former is used on RoCM and the logger.debug_once("Using FlashInfer prefill for MLA")
# latter has an additional parameter to control FA2 vs FA3 self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
self.flash_attn_varlen_func = flash_attn_varlen_func self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
self.vllm_flash_attn_version = get_flash_attn_version() self._pad_v = False
if self.vllm_flash_attn_version is not None: else: # Use FlashAttention
self.flash_attn_varlen_func = \ logger.debug_once("Using FlashAttention prefill for MLA")
functools.partial(flash_attn_varlen_func, self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
fa_version=self.vllm_flash_attn_version) self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
# For MLA the v head dim is smaller than qk head dim so we pad out # Handle the differences between the flash_attn_varlen from
# v with 0s to match the qk head dim for attention backends that do # flash_attn and the one from vllm_flash_attn. The former is used on
# not support different headdims # RoCM and the latter has an additional parameter to control
# We don't need to pad V if we are on a hopper system with FA3 # FA2 vs FA3
self._pad_v = self.vllm_flash_attn_version is None or not ( self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version == 3 self.vllm_flash_attn_version = get_flash_attn_version()
and current_platform.get_device_capability()[0] == 9) if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
def _flash_attn_varlen_diff_headdims(self, def _flash_attn_varlen_diff_headdims(self,
q, q,
@ -705,6 +847,58 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return attn_out, lse return attn_out, lse
return attn_out return attn_out
def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q,
k, v, return_softmax_lse):
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill.query_start_loc,
cu_seqlens_k=prefill.query_start_loc,
max_seqlen_q=prefill.max_query_len,
max_seqlen_k=prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=return_softmax_lse,
)
def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q,
k, v, return_softmax_lse):
assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None
return prefill.prefill_main.run(
q=q,
k=k,
v=v,
return_lse=return_softmax_lse,
)
def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
chunk_idx: int, q, k, v):
assert prefill.chunked_context is not None
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill.query_start_loc,
cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx],
max_seqlen_q=prefill.max_query_len,
max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
chunk_idx: int, q, k, v):
assert isinstance(prefill, FlashInferPrefillMetadata)
return prefill.prefill_chunks[chunk_idx].run(
q=q,
k=k,
v=v,
return_lse=True,
)
def _v_up_proj(self, x): def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L) # Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@ -803,18 +997,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1) dim=-1)
attn_output, attn_softmax_lse = \ attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
self._flash_attn_varlen_diff_headdims( prefill=prefill_metadata,
chunk_idx=i,
q=q, q=q,
k=k, k=k,
v=v, v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
) )
if output is None: if output is None:
@ -854,16 +1042,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
output = self._flash_attn_varlen_diff_headdims( output = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill,
q=q, q=q,
k=k, k=k,
v=v, v=v,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context, return_softmax_lse=has_context,
) )
@ -908,7 +1091,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None:

View File

@ -91,6 +91,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
# Clone q_nope and q_pe to make sure strides computation is correct. # Clone q_nope and q_pe to make sure strides computation is correct.
q_nope = q_nope.clone() q_nope = q_nope.clone()
q_pe = q_pe.clone() q_pe = q_pe.clone()
ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache, ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens, attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table, self.scale) attn_metadata.decode.block_table, self.scale)

View File

@ -4,14 +4,17 @@ import abc
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
import numpy as np import numpy as np
import torch import torch
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv from vllm.utils import cdiv
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionImpl
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
@ -98,39 +101,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
return False return False
def validate_kv_sharing_target(current_layer_name, target_layer_name,
static_forward_context):
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
f"is not valid: target layer {target_layer_name} ")
if current_layer_name == target_layer_name:
raise ValueError(error_msg +
"cannot be the same as the current layer.")
if target_layer_name not in static_forward_context:
from vllm.model_executor.models.utils import extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx = extract_layer_index(current_layer_name)
target_layer_idx = extract_layer_index(target_layer_name)
if current_layer_idx <= target_layer_idx:
raise ValueError(error_msg + "must come before the current layer.")
else:
raise ValueError(error_msg +
"is not a valid Attention layer in the model.")
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type = static_forward_context[
target_layer_name].attn_type
expected = static_forward_context[current_layer_name].attn_type
if target_layer_attn_type != expected:
raise ValueError(
error_msg +
f"must be the same type as the current layer ({expected}).")
@functools.lru_cache @functools.lru_cache
def get_kv_cache_layout(): def get_kv_cache_layout():
# Override with format specified by the user. # Override with format specified by the user.
@ -144,6 +114,71 @@ def get_kv_cache_layout():
return cache_layout return cache_layout
@dataclass
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
"""
window_left: int
logits_soft_cap: Optional[float]
sm_scale: float
def get_per_layer_parameters(
vllm_config: VllmConfig,
cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""
layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
impl = layer.impl
assert isinstance(impl, cls_)
# Infer hyperparameters from the attention layer
window_size = getattr(impl, "sliding_window", None)
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = getattr(impl, "logits_soft_cap", None)
sm_scale = impl.scale
per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)
return per_layer_params
def infer_global_hyperparameters(
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert len(per_layer_params) > 0, "No attention layers found in the model."
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")
return global_params
# #
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into # Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel # local attention blocks, where each block is passed to the attention kernel