[Attention] MLA - Flashinfer Ragged Prefill (#20034)

Alexander Matveev 2025-07-10 23:17:47 -04:00 committed by GitHub
parent 922f316441
commit 5b032352cc
10 changed files with 421 additions and 214 deletions

View File

View File

@ -3,16 +3,10 @@
import filecmp
import shutil
import tempfile
from collections import defaultdict
from pathlib import Path
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.config import KVTransferConfig
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@ -25,65 +19,6 @@ PROMPTS = [
SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)
class TestSharedStorageConnector(SharedStorageConnector):
def __init__(self, config: VllmConfig, role):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = tempfile.gettempdir(
) + f"/connector_{self.name}-{self.role.name}_events.log"
# Start with an empty file
with open(self._event_file, "w") as _:
pass
def __getattribute__(self, name):
if name in ("_connector", "call_record", "name", "_event_file",
"__class__", "__dict__", "__getattribute__",
"__init__"): # avoid recursion
return object.__getattribute__(self, name)
if not hasattr(self._connector, name):
return object.__getattribute__(self, name)
attr = getattr(self._connector, name)
# Intercept calls to the connector interface and write an event
# for each one to a file, which can be read back in the main test proc.
if callable(attr):
def wrapper(*args, **kwargs):
self.call_record[name] += 1
# Include args that we're interested in
to_log = [name]
for arg in args:
if isinstance(arg, int):
to_log.append(str(arg))
elif isinstance(arg, KVCacheBlocks):
to_log.append(
f"num_blocks={[len(b) for b in arg.blocks]}")
# Log the event as a line to the file
try:
with open(self._event_file, "a") as f:
f.write(' '.join(to_log) + "\n")
except Exception as e:
print(f"[ERROR] Could not log event {name} "
f"for {self.name}: {e}")
return attr(*args, **kwargs)
return wrapper
return attr
# This relies on "fork" multiprocessing method being used.
# It's the default but vLLM may fall back to spawn if for example CUDA
# is already initialized.
KVConnectorFactory.register_connector("TestSharedStorageConnector",
TestSharedStorageConnector.__module__,
TestSharedStorageConnector.__name__)
# Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool:
"""Compares two directories recursively for identical content."""
@ -118,19 +53,27 @@ def test_multi_shared_storage_connector_consistency():
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [{
"kv_connector": "TestSharedStorageConnector",
"kv_role": "kv_both",
"kv_connector":
"TestSharedStorageConnector",
"kv_role":
"kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_1_path),
"name": "storage1",
}
},
"kv_connector_module_path":
"tests.v1.kv_connector.unit.utils",
}, {
"kv_connector": "TestSharedStorageConnector",
"kv_role": "kv_both",
"kv_connector":
"TestSharedStorageConnector",
"kv_role":
"kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_2_path),
"name": "storage2",
}
},
"kv_connector_module_path":
"tests.v1.kv_connector.unit.utils",
}]
},
)

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections import defaultdict
from typing import Any, Optional
import torch
@ -7,6 +9,11 @@ import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
@ -187,3 +194,58 @@ def create_model_runner_output(
finished_sending=finished_sending,
finished_recving=finished_recving,
)
class TestSharedStorageConnector(SharedStorageConnector):
def __init__(self, config: VllmConfig, role):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = tempfile.gettempdir(
) + f"/connector_{self.name}-{self.role.name}_events.log"
# Start with an empty file
with open(self._event_file, "w") as _:
pass
def __getattribute__(self, name):
if name in ("_connector", "call_record", "name", "_event_file",
"__class__", "__dict__", "__getattribute__",
"__init__"): # avoid recursion
return object.__getattribute__(self, name)
if not hasattr(self._connector, name):
return object.__getattribute__(self, name)
attr = getattr(self._connector, name)
# Intercept calls to the connector interface and write an event
# for each one to a file, which can be read back in the main test proc.
if callable(attr):
def wrapper(*args, **kwargs):
self.call_record[name] += 1
# Include args that we're interested in
to_log = [name]
for arg in args:
if isinstance(arg, int):
to_log.append(str(arg))
elif isinstance(arg, KVCacheBlocks):
to_log.append(
f"num_blocks={[len(b) for b in arg.blocks]}")
# Log the event as a line to the file
try:
with open(self._event_file, "a") as f:
f.write(' '.join(to_log) + "\n")
except Exception as e:
print(f"[ERROR] Could not log event {name} "
f"for {self.name}: {e}")
return attr(*args, **kwargs)
return wrapper
return attr
KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__,
TestSharedStorageConnector.__name__)
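For reference, the interception pattern used by TestSharedStorageConnector in isolation: every callable attribute is looked up on the real connector, counted, and then forwarded. This is a standalone sketch with illustrative names, not part of vLLM.

from collections import defaultdict


class CountingProxy:
    """Counts calls to every method of the wrapped object."""

    def __init__(self, target):
        self._target = target
        self.call_record = defaultdict(int)

    def __getattribute__(self, name):
        # Avoid infinite recursion for the proxy's own attributes.
        if name in ("_target", "call_record", "__class__", "__dict__",
                    "__getattribute__", "__init__"):
            return object.__getattribute__(self, name)
        attr = getattr(self._target, name)
        if callable(attr):

            def wrapper(*args, **kwargs):
                self.call_record[name] += 1
                return attr(*args, **kwargs)

            return wrapper
        return attr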

View File

@ -10,6 +10,7 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
@ -21,7 +22,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
class Attention(nn.Module):

View File

@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def validate_kv_sharing_target(current_layer_name, target_layer_name,
static_forward_context):
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
f"is not valid: target layer {target_layer_name} ")
if current_layer_name == target_layer_name:
raise ValueError(error_msg +
"cannot be the same as the current layer.")
if target_layer_name not in static_forward_context:
from vllm.model_executor.models.utils import extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx = extract_layer_index(current_layer_name)
target_layer_idx = extract_layer_index(target_layer_name)
if current_layer_idx <= target_layer_idx:
raise ValueError(error_msg + "must come before the current layer.")
else:
raise ValueError(error_msg +
"is not a valid Attention layer in the model.")
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type = static_forward_context[
target_layer_name].attn_type
expected = static_forward_context[current_layer_name].attn_type
if target_layer_attn_type != expected:
raise ValueError(
error_msg +
f"must be the same type as the current layer ({expected}).")

View File

@ -53,6 +53,12 @@ DEFAULT_LOGGING_CONFIG = {
}
@lru_cache
def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
# Set the stacklevel to 2 to print the original caller's line info
logger.debug(msg, *args, stacklevel=2)
@lru_cache
def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
# Set the stacklevel to 2 to print the original caller's line info
@ -74,6 +80,13 @@ class _VllmLogger(Logger):
`intel_extension_for_pytorch.utils._logger`.
"""
def debug_once(self, msg: str, *args: Hashable) -> None:
"""
As [`debug`][logging.Logger.debug], but subsequent calls with
the same message are silently dropped.
"""
_print_debug_once(self, msg, *args)
def info_once(self, msg: str, *args: Hashable) -> None:
"""
As [`info`][logging.Logger.info], but subsequent calls with
@ -132,6 +145,7 @@ def init_logger(name: str) -> _VllmLogger:
logger = logging.getLogger(name)
methods_to_patch = {
"debug_once": _print_debug_once,
"info_once": _print_info_once,
"warning_once": _print_warning_once,
}
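A usage sketch for the new debug_once helper (the logger name is arbitrary): because _print_debug_once is wrapped in lru_cache, repeated calls with the same message and args are dropped after the first one.

from vllm.logger import init_logger

logger = init_logger("vllm.example")  # hypothetical module name

for _ in range(3):
    # Emitted once; the later calls hit the lru_cache and are dropped.
    logger.debug_once("Using FlashInfer prefill for MLA")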

View File

@ -14,13 +14,14 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout)
PerLayerParameters,
get_kv_cache_layout,
get_per_layer_parameters,
infer_global_hyperparameters)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
@ -93,70 +94,6 @@ class FlashInferBackend(AttentionBackend):
return stride_order
@dataclass
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
"""
window_left: int
logits_soft_cap: Optional[float]
sm_scale: float
def get_per_layer_parameters(
vllm_config: VllmConfig) -> dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""
layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
impl = layer.impl
assert isinstance(impl, FlashInferImpl)
# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale
per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)
return per_layer_params
def infer_global_hyperparameters(
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert len(per_layer_params) > 0, "No attention layers found in the model."
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")
return global_params
@dataclass
class FlashInferMetadata:
@ -336,7 +273,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def _plan(self, attn_metadata: FlashInferMetadata):
if self.global_hyperparameters is None:
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
get_per_layer_parameters(self.vllm_config, FlashInferImpl))
if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(

View File

@ -189,8 +189,8 @@ return curr_o @ W_O
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
import torch
@ -208,7 +208,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
CommonAttentionMetadata,
get_per_layer_parameters,
infer_global_hyperparameters)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
@ -221,6 +223,12 @@ except ImportError:
from flash_attn import flash_attn_varlen_func
is_vllm_fa = False
try:
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
flashinfer_available = True
except ImportError:
flashinfer_available = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
@ -290,6 +298,13 @@ class MLACommonPrefillMetadata:
chunked_context: Optional[ChunkedContextMetadata] = None
@dataclass
class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
prefill_main: Optional['BatchPrefillWithRaggedKVCacheWrapper'] = None
prefill_chunks: list['BatchPrefillWithRaggedKVCacheWrapper'] = field(
default_factory=list)
@dataclass
class MLACommonDecodeMetadata:
block_table: torch.Tensor
@ -328,7 +343,8 @@ class MLACommonMetadata(Generic[D]):
head_dim: Optional[int] = None
decode: Optional[D] = None
prefill: Optional[MLACommonPrefillMetadata] = None
prefill: Optional[Union[MLACommonPrefillMetadata,
FlashInferPrefillMetadata]] = None
def __post_init__(self):
if self.head_dim is not None:
@ -338,6 +354,20 @@ class MLACommonMetadata(Generic[D]):
M = TypeVar("M", bound=MLACommonMetadata)
def use_flashinfer_prefill() -> bool:
if flashinfer_available:
# On Blackwell, default to FlashInfer prefill if it's available, since
# it's faster than FA2.
return current_platform.has_device_capability(100)
return False
# Currently 394MB; this can be tuned based on the GEMM sizes used.
# Chosen to be the same as sglang:
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
@ -392,6 +422,101 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
)
self.block_table = block_table
self._use_fi_prefill = use_flashinfer_prefill()
self.prefill_metadata_cls = FlashInferPrefillMetadata \
if self._use_fi_prefill else MLACommonPrefillMetadata
if self._use_fi_prefill:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=runner.device)
self._fi_prefill_main: Optional[
BatchPrefillWithRaggedKVCacheWrapper] = None
self._fi_prefill_chunks: list[
BatchPrefillWithRaggedKVCacheWrapper] = []
self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(runner.vllm_config, MLACommonImpl))
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc
has_context = False
if prefill.chunked_context is not None:
chunked_context = prefill.chunked_context
has_context = True
if self._fi_prefill_main is None:
self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper(
self._workspace_buffer, "NHD", backend="cutlass")
if has_context:
num_chunks = chunked_context.cu_seq_lens.shape[0]
# Allocate more prefill chunk wrappers if needed
if len(self._fi_prefill_chunks) < num_chunks:
for _ in range(len(self._fi_prefill_chunks), num_chunks):
self._fi_prefill_chunks.append(
BatchPrefillWithRaggedKVCacheWrapper(
self._workspace_buffer, "NHD", backend="cutlass"))
assert num_chunks <= len(self._fi_prefill_chunks)
# In MLA, the non-latent num_qo_heads == num_kv_heads
num_qo_heads = self.runner.num_query_heads
num_kv_heads = num_qo_heads
# Sanity: Verify that num_kv_heads == 1 since it is latent space
assert self.kv_cache_spec.num_kv_heads == 1
# Get non-latent head_dim_qk and head_dim_vo
head_dim_qk = (self.mla_dims.qk_nope_head_dim +
self.mla_dims.qk_rope_head_dim)
head_dim_vo = self.mla_dims.v_head_dim
# For main run, qo_indptr == kv_indptr
kv_indptr = qo_indptr.clone()
# Prepare main prefill
self._fi_prefill_main.plan(
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
causal=True, # This is main run
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.runner.dtype,
kv_data_type=self.kv_cache_spec.dtype,
)
# Prepare context prefills
if has_context:
for i in range(num_chunks):
kv_indptr_chunk = chunked_context.cu_seq_lens[i]
self._fi_prefill_chunks[i].plan(
qo_indptr=qo_indptr,
kv_indptr=kv_indptr_chunk,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
causal=False, # This is context run
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.
logits_soft_cap,
q_data_type=self.runner.dtype,
kv_data_type=self.kv_cache_spec.dtype,
)
prefill.prefill_main = self._fi_prefill_main
prefill.prefill_chunks = self._fi_prefill_chunks
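The indptr layout the wrappers are planned with, using made-up numbers: the ragged FlashInfer wrappers index queries and keys by prefix sums, so for the causal main run (new tokens attending to new tokens) kv_indptr is simply a copy of query_start_loc, while each context chunk gets its own kv_indptr row from chunked_context.cu_seq_lens.

import torch

# Two prefill requests with 3 and 5 new tokens (illustrative values).
qo_indptr = torch.tensor([0, 3, 8], dtype=torch.int32)  # query_start_loc
kv_indptr_main = qo_indptr.clone()                      # main causal run
# Chunk 0 of the cached context: 16 and 32 cached tokens respectively.
kv_indptr_chunk0 = torch.tensor([0, 16, 48], dtype=torch.int32)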
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are at the front and the "prefill" requests are at the back
@ -572,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
assert max(chunked_context_metadata.max_seq_lens) <= \
self.chunked_prefill_workspace_size
prefill_metadata = MLACommonPrefillMetadata(
prefill_metadata = self.prefill_metadata_cls(
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
@ -586,7 +711,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens=seq_lens[:self._num_decodes],
)
return self.metadata_cls(
attn_metadata = self.metadata_cls(
num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc,
slot_mapping=slot_mapping,
@ -599,6 +724,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode=decode_metadata,
)
if self._use_fi_prefill and self._num_prefills > 0:
assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
self._build_fi_prefill_wrappers(attn_metadata.prefill)
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
@ -649,23 +780,34 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.v_head_dim = v_head_dim
self.kv_b_proj = kv_b_proj
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
# latter has an additional parameter to control FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
if use_flashinfer_prefill():
logger.debug_once("Using FlashInfer prefill for MLA")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
self._pad_v = False
else: # Use FlashAttention
logger.debug_once("Using FlashAttention prefill for MLA")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
# Handle the differences between the flash_attn_varlen from
# flash_attn and the one from vllm_flash_attn. The former is used on
# ROCm and the latter has an additional parameter to control
# FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a Hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
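For reference, a sketch of the padding applied when _pad_v is set, with DeepSeek-style MLA head sizes (192 = 128 nope + 64 rope for Q/K, 128 for V); the shapes are illustrative only.

import torch

q = torch.randn(8, 16, 192)  # (tokens, heads, qk_head_dim)
v = torch.randn(8, 16, 128)  # (tokens, heads, v_head_dim)

# Zero-pad V's head dim up to the Q/K head dim so kernels that require equal
# head dims can run; the padded columns are sliced off the output afterwards.
v_padded = torch.nn.functional.pad(v, (0, q.shape[-1] - v.shape[-1]))
assert v_padded.shape == (8, 16, 192)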
def _flash_attn_varlen_diff_headdims(self,
q,
@ -705,6 +847,58 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return attn_out, lse
return attn_out
def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q,
k, v, return_softmax_lse):
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill.query_start_loc,
cu_seqlens_k=prefill.query_start_loc,
max_seqlen_q=prefill.max_query_len,
max_seqlen_k=prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=return_softmax_lse,
)
def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q,
k, v, return_softmax_lse):
assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None
return prefill.prefill_main.run(
q=q,
k=k,
v=v,
return_lse=return_softmax_lse,
)
def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
chunk_idx: int, q, k, v):
assert prefill.chunked_context is not None
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill.query_start_loc,
cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx],
max_seqlen_q=prefill.max_query_len,
max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
chunk_idx: int, q, k, v):
assert isinstance(prefill, FlashInferPrefillMetadata)
return prefill.prefill_chunks[chunk_idx].run(
q=q,
k=k,
v=v,
return_lse=True,
)
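The context-chunk runs return the softmax LSE so that partial outputs computed over disjoint K/V chunks can be merged exactly. A standalone sketch of that merge (just the math, not the kernel vLLM actually uses):

import torch


def merge_partial_attention(o1, lse1, o2, lse2):
    # o*: (tokens, heads, head_dim); lse*: (tokens, heads), the log-sum-exp
    # of the scores each partial output was normalized with.
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - max_lse).unsqueeze(-1)
    return (o1 * w1 + o2 * w2) / (w1 + w2)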
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@ -803,18 +997,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
prefill=prefill_metadata,
chunk_idx=i,
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
if output is None:
@ -854,16 +1042,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
output = self._flash_attn_varlen_diff_headdims(
output = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill,
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)
@ -908,7 +1091,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if output_scale is not None:

View File

@ -91,6 +91,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
# Clone q_nope and q_pe to make sure strides computation is correct.
q_nope = q_nope.clone()
q_pe = q_pe.clone()
ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table, self.scale)

View File

@ -4,14 +4,17 @@ import abc
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
import numpy as np
import torch
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionImpl
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
@ -98,39 +101,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
return False
def validate_kv_sharing_target(current_layer_name, target_layer_name,
static_forward_context):
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
f"is not valid: target layer {target_layer_name} ")
if current_layer_name == target_layer_name:
raise ValueError(error_msg +
"cannot be the same as the current layer.")
if target_layer_name not in static_forward_context:
from vllm.model_executor.models.utils import extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx = extract_layer_index(current_layer_name)
target_layer_idx = extract_layer_index(target_layer_name)
if current_layer_idx <= target_layer_idx:
raise ValueError(error_msg + "must come before the current layer.")
else:
raise ValueError(error_msg +
"is not a valid Attention layer in the model.")
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type = static_forward_context[
target_layer_name].attn_type
expected = static_forward_context[current_layer_name].attn_type
if target_layer_attn_type != expected:
raise ValueError(
error_msg +
f"must be the same type as the current layer ({expected}).")
@functools.lru_cache
def get_kv_cache_layout():
# Override with format specified by the user.
@ -144,6 +114,71 @@ def get_kv_cache_layout():
return cache_layout
@dataclass
class PerLayerParameters:
"""
Currently, the FlashInfer backend only supports models in which all layers share
the same values for the following hyperparameters.
"""
window_left: int
logits_soft_cap: Optional[float]
sm_scale: float
def get_per_layer_parameters(
vllm_config: VllmConfig,
cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""
layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
impl = layer.impl
assert isinstance(impl, cls_)
# Infer hyperparameters from the attention layer
window_size = getattr(impl, "sliding_window", None)
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = getattr(impl, "logits_soft_cap", None)
sm_scale = impl.scale
per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)
return per_layer_params
def infer_global_hyperparameters(
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, the FlashInfer backend only supports models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert len(per_layer_params) > 0, "No attention layers found in the model."
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")
return global_params
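A small sketch of the invariant these helpers enforce, with made-up values; PerLayerParameters is a dataclass, so equality is field-wise.

from vllm.v1.attention.backends.utils import (PerLayerParameters,
                                              infer_global_hyperparameters)

params = {
    "model.layers.0.self_attn.attn": PerLayerParameters(
        window_left=-1, logits_soft_cap=None, sm_scale=0.125),
    "model.layers.1.self_attn.attn": PerLayerParameters(
        window_left=-1, logits_soft_cap=None, sm_scale=0.125),
}

# All layers agree, so their shared values become the global hyperparameters;
# any mismatch would trip the assertion in infer_global_hyperparameters.
assert infer_global_hyperparameters(params) == params[
    "model.layers.0.self_attn.attn"]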
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel