diff --git a/tests/v1/kv_connector/__init__.py b/tests/v1/kv_connector/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index e82691cd05e25..b1780d8a9af80 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -3,16 +3,10 @@ import filecmp import shutil import tempfile -from collections import defaultdict from pathlib import Path from vllm import LLM, SamplingParams -from vllm.config import KVTransferConfig, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa - SharedStorageConnector) -from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.config import KVTransferConfig MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @@ -25,65 +19,6 @@ PROMPTS = [ SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) -class TestSharedStorageConnector(SharedStorageConnector): - - def __init__(self, config: VllmConfig, role): - self.name = config.kv_transfer_config.kv_connector_extra_config["name"] - self._connector = SharedStorageConnector(config, role) - self.call_record: dict[str, int] = defaultdict(int) - # Use a unique temp file per connector - self._event_file = tempfile.gettempdir( - ) + f"/connector_{self.name}-{self.role.name}_events.log" - # Start with an empty file - with open(self._event_file, "w") as _: - pass - - def __getattribute__(self, name): - if name in ("_connector", "call_record", "name", "_event_file", - "__class__", "__dict__", "__getattribute__", - "__init__"): # avoid recursion - return object.__getattribute__(self, name) - if not hasattr(self._connector, name): - return object.__getattribute__(self, name) - attr = getattr(self._connector, name) - - # Intercept calls to the connector interface and write an event - # for each one to a file, which can be read back in the main test proc. - if callable(attr): - - def wrapper(*args, **kwargs): - self.call_record[name] += 1 - - # Include args that we're interested in - to_log = [name] - for arg in args: - if isinstance(arg, int): - to_log.append(str(arg)) - elif isinstance(arg, KVCacheBlocks): - to_log.append( - f"num_blocks={[len(b) for b in arg.blocks]}") - - # Log the event as a line to the file - try: - with open(self._event_file, "a") as f: - f.write(' '.join(to_log) + "\n") - except Exception as e: - print(f"[ERROR] Could not log event {name} " - f"for {self.name}: {e}") - return attr(*args, **kwargs) - - return wrapper - return attr - - -# This relies on "fork" multiprocessing method being used. -# It's the default but vLLM may fall back to spawn if for example CUDA -# is already initialized. 
-KVConnectorFactory.register_connector("TestSharedStorageConnector", - TestSharedStorageConnector.__module__, - TestSharedStorageConnector.__name__) - - # Helper function to compare directories recursively def _compare_directories(dir1: Path, dir2: Path) -> bool: """Compares two directories recursively for identical content.""" @@ -118,19 +53,27 @@ def test_multi_shared_storage_connector_consistency(): kv_role="kv_both", kv_connector_extra_config={ "connectors": [{ - "kv_connector": "TestSharedStorageConnector", - "kv_role": "kv_both", + "kv_connector": + "TestSharedStorageConnector", + "kv_role": + "kv_both", "kv_connector_extra_config": { "shared_storage_path": str(storage_1_path), "name": "storage1", - } + }, + "kv_connector_module_path": + "tests.v1.kv_connector.unit.utils", }, { - "kv_connector": "TestSharedStorageConnector", - "kv_role": "kv_both", + "kv_connector": + "TestSharedStorageConnector", + "kv_role": + "kv_both", "kv_connector_extra_config": { "shared_storage_path": str(storage_2_path), "name": "storage2", - } + }, + "kv_connector_module_path": + "tests.v1.kv_connector.unit.utils", }] }, ) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 983d900606fc9..cf20d44fbaaed 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import tempfile +from collections import defaultdict from typing import Any, Optional import torch @@ -7,6 +9,11 @@ import torch from vllm import SamplingParams from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig) +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa + SharedStorageConnector) +from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) @@ -187,3 +194,58 @@ def create_model_runner_output( finished_sending=finished_sending, finished_recving=finished_recving, ) + + +class TestSharedStorageConnector(SharedStorageConnector): + + def __init__(self, config: VllmConfig, role): + self.name = config.kv_transfer_config.kv_connector_extra_config["name"] + self._connector = SharedStorageConnector(config, role) + self.call_record: dict[str, int] = defaultdict(int) + # Use a unique temp file per connector + self._event_file = tempfile.gettempdir( + ) + f"/connector_{self.name}-{self.role.name}_events.log" + # Start with an empty file + with open(self._event_file, "w") as _: + pass + + def __getattribute__(self, name): + if name in ("_connector", "call_record", "name", "_event_file", + "__class__", "__dict__", "__getattribute__", + "__init__"): # avoid recursion + return object.__getattribute__(self, name) + if not hasattr(self._connector, name): + return object.__getattribute__(self, name) + attr = getattr(self._connector, name) + + # Intercept calls to the connector interface and write an event + # for each one to a file, which can be read back in the main test proc. 
+ if callable(attr): + + def wrapper(*args, **kwargs): + self.call_record[name] += 1 + + # Include args that we're interested in + to_log = [name] + for arg in args: + if isinstance(arg, int): + to_log.append(str(arg)) + elif isinstance(arg, KVCacheBlocks): + to_log.append( + f"num_blocks={[len(b) for b in arg.blocks]}") + + # Log the event as a line to the file + try: + with open(self._event_file, "a") as f: + f.write(' '.join(to_log) + "\n") + except Exception as e: + print(f"[ERROR] Could not log event {name} " + f"for {self.name}: {e}") + return attr(*args, **kwargs) + + return wrapper + return attr + + +KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__, + TestSharedStorageConnector.__name__) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index f0ad68b16405e..3d5746837beae 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -10,6 +10,7 @@ import torch.nn.functional as F import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group, @@ -21,7 +22,6 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op -from vllm.v1.attention.backends.utils import validate_kv_sharing_target class Attention(nn.Module): diff --git a/vllm/attention/utils/kv_sharing_utils.py b/vllm/attention/utils/kv_sharing_utils.py new file mode 100644 index 0000000000000..b4ae8bdf4d762 --- /dev/null +++ b/vllm/attention/utils/kv_sharing_utils.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +def validate_kv_sharing_target(current_layer_name, target_layer_name, + static_forward_context): + error_msg = (f"Specified KV sharing target layer for {current_layer_name} " + f"is not valid: target layer {target_layer_name} ") + + if current_layer_name == target_layer_name: + raise ValueError(error_msg + + "cannot be the same as the current layer.") + + if target_layer_name not in static_forward_context: + from vllm.model_executor.models.utils import extract_layer_index + + # If target layer name is not in the static fwd context, it means either + # a) the target layer does not come BEFORE the current layer, or + # b) the target layer is not an Attention layer that exists in the model + current_layer_idx = extract_layer_index(current_layer_name) + target_layer_idx = extract_layer_index(target_layer_name) + if current_layer_idx <= target_layer_idx: + raise ValueError(error_msg + "must come before the current layer.") + else: + raise ValueError(error_msg + + "is not a valid Attention layer in the model.") + + # Currently KV sharing is only supported between layers of the same type + target_layer_attn_type = static_forward_context[ + target_layer_name].attn_type + expected = static_forward_context[current_layer_name].attn_type + if target_layer_attn_type != expected: + raise ValueError( + error_msg + + f"must be the same type as the current layer ({expected}).") diff --git a/vllm/logger.py b/vllm/logger.py index 0ddb83cb8ba7a..69aaf4390a7db 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -53,6 +53,12 
@@ DEFAULT_LOGGING_CONFIG = { } +@lru_cache +def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None: + # Set the stacklevel to 2 to print the original caller's line info + logger.debug(msg, *args, stacklevel=2) + + @lru_cache def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None: # Set the stacklevel to 2 to print the original caller's line info @@ -74,6 +80,13 @@ class _VllmLogger(Logger): `intel_extension_for_pytorch.utils._logger`. """ + def debug_once(self, msg: str, *args: Hashable) -> None: + """ + As [`debug`][logging.Logger.debug], but subsequent calls with + the same message are silently dropped. + """ + _print_debug_once(self, msg, *args) + def info_once(self, msg: str, *args: Hashable) -> None: """ As [`info`][logging.Logger.info], but subsequent calls with @@ -132,6 +145,7 @@ def init_logger(name: str) -> _VllmLogger: logger = logging.getLogger(name) methods_to_patch = { + "debug_once": _print_debug_once, "info_once": _print_info_once, "warning_once": _print_warning_once, } diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 860309faa9053..4cca618f6b3c9 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -14,13 +14,14 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) -from vllm.attention.layer import Attention -from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata, - get_kv_cache_layout) + PerLayerParameters, + get_kv_cache_layout, + get_per_layer_parameters, + infer_global_hyperparameters) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -93,70 +94,6 @@ class FlashInferBackend(AttentionBackend): return stride_order -@dataclass -class PerLayerParameters: - """ - Currently, FlashInfer backend only support models in which all layers share - the same values for the following hyperparameters. - """ - - window_left: int - logits_soft_cap: Optional[float] - sm_scale: float - - -def get_per_layer_parameters( - vllm_config: VllmConfig) -> dict[str, PerLayerParameters]: - """ - Scan all attention layers and determine some hyperparameters - to use during `plan`. - """ - - layers = get_layers_from_vllm_config(vllm_config, Attention) - per_layer_params: dict[str, PerLayerParameters] = {} - - for key, layer in layers.items(): - impl = layer.impl - assert isinstance(impl, FlashInferImpl) - - # Infer hyperparameters from the attention layer - window_size = impl.sliding_window - window_left = window_size[0] if window_size is not None else -1 - logits_soft_cap = impl.logits_soft_cap - sm_scale = impl.scale - - per_layer_params[key] = PerLayerParameters(window_left, - logits_soft_cap, sm_scale) - - return per_layer_params - - -def infer_global_hyperparameters( - per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: - """ - Currently, FlashInfer backend only support models in which all layers share - the same values for the following hyperparameters: - - `window_left` - - `logits_soft_cap` - - `sm_scale` - - So this function asserts that all layers share the same values for these - hyperparameters and returns the global values. 
- """ - - assert len(per_layer_params) > 0, "No attention layers found in the model." - - param_sets = list(per_layer_params.values()) - global_params = param_sets[0] - for params in param_sets: - assert params == global_params, ( - "FlashInfer backend currently only supports models in which all " - "layers share the same values for the following hyperparameters: " - "`window_left`, `logits_soft_cap`, `sm_scale`.") - - return global_params - - @dataclass class FlashInferMetadata: @@ -336,7 +273,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): def _plan(self, attn_metadata: FlashInferMetadata): if self.global_hyperparameters is None: self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(self.vllm_config)) + get_per_layer_parameters(self.vllm_config, FlashInferImpl)) if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f2aaf59a40f88..970de229e139e 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -189,8 +189,8 @@ return curr_o @ W_O import functools from abc import abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union import torch @@ -208,7 +208,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) + CommonAttentionMetadata, + get_per_layer_parameters, + infer_global_hyperparameters) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -221,6 +223,12 @@ except ImportError: from flash_attn import flash_attn_varlen_func is_vllm_fa = False +try: + from flashinfer import BatchPrefillWithRaggedKVCacheWrapper + flashinfer_available = True +except ImportError: + flashinfer_available = False + if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch @@ -290,6 +298,13 @@ class MLACommonPrefillMetadata: chunked_context: Optional[ChunkedContextMetadata] = None +@dataclass +class FlashInferPrefillMetadata(MLACommonPrefillMetadata): + prefill_main: Optional['BatchPrefillWithRaggedKVCacheWrapper'] = None + prefill_chunks: list['BatchPrefillWithRaggedKVCacheWrapper'] = field( + default_factory=list) + + @dataclass class MLACommonDecodeMetadata: block_table: torch.Tensor @@ -328,7 +343,8 @@ class MLACommonMetadata(Generic[D]): head_dim: Optional[int] = None decode: Optional[D] = None - prefill: Optional[MLACommonPrefillMetadata] = None + prefill: Optional[Union[MLACommonPrefillMetadata, + FlashInferPrefillMetadata]] = None def __post_init__(self): if self.head_dim is not None: @@ -338,6 +354,20 @@ class MLACommonMetadata(Generic[D]): M = TypeVar("M", bound=MLACommonMetadata) +def use_flashinfer_prefill() -> bool: + if flashinfer_available: + # For blackwell default to flashinfer prefill if its available since + # its faster than FA2. + return current_platform.has_device_capability(100) + return False + + +# Currently 394MB, this can be tuned based on GEMM sizes used. 
+# Choosen to be the same as sglang: +# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37 +FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024 + + class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ NOTE: Please read the comment at the top of the file before trying to @@ -392,6 +422,101 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ) self.block_table = block_table + self._use_fi_prefill = use_flashinfer_prefill() + self.prefill_metadata_cls = FlashInferPrefillMetadata \ + if self._use_fi_prefill else MLACommonPrefillMetadata + + if self._use_fi_prefill: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=runner.device) + + self._fi_prefill_main: Optional[ + BatchPrefillWithRaggedKVCacheWrapper] = None + self._fi_prefill_chunks: list[ + BatchPrefillWithRaggedKVCacheWrapper] = [] + + self._global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(runner.vllm_config, MLACommonImpl)) + + def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): + qo_indptr = prefill.query_start_loc + + has_context = False + if prefill.chunked_context is not None: + chunked_context = prefill.chunked_context + has_context = True + + if self._fi_prefill_main is None: + self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper( + self._workspace_buffer, "NHD", backend="cutlass") + + if has_context: + num_chunks = chunked_context.cu_seq_lens.shape[0] + # Allocate more prefill chunk wrappers if needed + if len(self._fi_prefill_chunks) < num_chunks: + for _ in range(len(self._fi_prefill_chunks), num_chunks): + self._fi_prefill_chunks.append( + BatchPrefillWithRaggedKVCacheWrapper( + self._workspace_buffer, "NHD", backend="cutlass")) + assert num_chunks <= len(self._fi_prefill_chunks) + + # In MLA, the non-latent num_qo_heads == num_kv_heads + num_qo_heads = self.runner.num_query_heads + num_kv_heads = num_qo_heads + + # Sanity: Verify that num_kv_heads == 1 since it is latent space + assert self.kv_cache_spec.num_kv_heads == 1 + + # Get non-latent head_dim_qk and head_dim_vo + head_dim_qk = (self.mla_dims.qk_nope_head_dim + + self.mla_dims.qk_rope_head_dim) + head_dim_vo = self.mla_dims.v_head_dim + + # For main run, qo_indptr == kv_indptr + kv_indptr = qo_indptr.clone() + + # Prepare main prefill + self._fi_prefill_main.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + causal=True, # This is main run + sm_scale=self._global_hyperparameters.sm_scale, + window_left=self._global_hyperparameters.window_left, + logits_soft_cap=self._global_hyperparameters.logits_soft_cap, + q_data_type=self.runner.dtype, + kv_data_type=self.kv_cache_spec.dtype, + ) + + # Prepare context prefills + if has_context: + for i in range(num_chunks): + kv_indptr_chunk = chunked_context.cu_seq_lens[i] + + self._fi_prefill_chunks[i].plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr_chunk, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + causal=False, # This is context run + sm_scale=self._global_hyperparameters.sm_scale, + window_left=self._global_hyperparameters.window_left, + logits_soft_cap=self._global_hyperparameters. 
+ logits_soft_cap, + q_data_type=self.runner.dtype, + kv_data_type=self.kv_cache_spec.dtype, + ) + + prefill.prefill_main = self._fi_prefill_main + prefill.prefill_chunks = self._fi_prefill_chunks + def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are and @@ -572,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): assert max(chunked_context_metadata.max_seq_lens) <= \ self.chunked_prefill_workspace_size - prefill_metadata = MLACommonPrefillMetadata( + prefill_metadata = self.prefill_metadata_cls( block_table=block_table_tensor[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, @@ -586,7 +711,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): seq_lens=seq_lens[:self._num_decodes], ) - return self.metadata_cls( + attn_metadata = self.metadata_cls( num_actual_tokens=num_actual_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, @@ -599,6 +724,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): decode=decode_metadata, ) + if self._use_fi_prefill and self._num_prefills > 0: + assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata) + self._build_fi_prefill_wrappers(attn_metadata.prefill) + + return attn_metadata + def can_run_in_cudagraph( self, common_attn_metadata: CommonAttentionMetadata) -> bool: return common_attn_metadata.max_query_len == 1 @@ -649,23 +780,34 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj - # Handle the differences between the flash_attn_varlen from flash_attn - # and the one from vllm_flash_attn. The former is used on RoCM and the - # latter has an additional parameter to control FA2 vs FA3 - self.flash_attn_varlen_func = flash_attn_varlen_func - self.vllm_flash_attn_version = get_flash_attn_version() - if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) + if use_flashinfer_prefill(): + logger.debug_once("Using FlashInfer prefill for MLA") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi + self._pad_v = False + else: # Use FlashAttention + logger.debug_once("Using FlashAttention prefill for MLA") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim for attention backends that do - # not support different headdims - # We don't need to pad V if we are on a hopper system with FA3 - self._pad_v = self.vllm_flash_attn_version is None or not ( - self.vllm_flash_attn_version == 3 - and current_platform.get_device_capability()[0] == 9) + # Handle the differences between the flash_attn_varlen from + # flash_attn and the one from vllm_flash_attn. 
The former is used on + # RoCM and the latter has an additional parameter to control + # FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = \ + functools.partial(flash_attn_varlen_func, + fa_version=self.vllm_flash_attn_version) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) def _flash_attn_varlen_diff_headdims(self, q, @@ -705,6 +847,58 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): return attn_out, lse return attn_out + def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q, + k, v, return_softmax_lse): + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.query_start_loc, + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.max_query_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=return_softmax_lse, + ) + + def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q, + k, v, return_softmax_lse): + assert isinstance(prefill, FlashInferPrefillMetadata) + assert prefill.prefill_main is not None + return prefill.prefill_main.run( + q=q, + k=k, + v=v, + return_lse=return_softmax_lse, + ) + + def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata, + chunk_idx: int, q, k, v): + assert prefill.chunked_context is not None + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx], + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + + def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata, + chunk_idx: int, q, k, v): + assert isinstance(prefill, FlashInferPrefillMetadata) + return prefill.prefill_chunks[chunk_idx].run( + q=q, + k=k, + v=v, + return_lse=True, + ) + def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -803,18 +997,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - attn_output, attn_softmax_lse = \ - self._flash_attn_varlen_diff_headdims( + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, q=q, k=k, v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, ) if output is None: @@ -854,16 +1042,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - output = self._flash_attn_varlen_diff_headdims( + output = self._run_prefill_new_tokens( + prefill=attn_metadata.prefill, q=q, k=k, v=v, - 
cu_seqlens_q=attn_metadata.prefill.query_start_loc, - cu_seqlens_k=attn_metadata.prefill.query_start_loc, - max_seqlen_q=attn_metadata.prefill.max_query_len, - max_seqlen_k=attn_metadata.prefill.max_query_len, - softmax_scale=self.scale, - causal=True, return_softmax_lse=has_context, ) @@ -908,7 +1091,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." if output_scale is not None: diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index db4b9c9537e5f..b2116bf114318 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -91,6 +91,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): # Clone q_nope and q_pe to make sure strides computation is correct. q_nope = q_nope.clone() q_pe = q_pe.clone() + ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache, attn_metadata.decode.seq_lens, attn_metadata.decode.block_table, self.scale) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index b0ebb00d9e6b9..3787b39a81be5 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,14 +4,17 @@ import abc import functools from abc import abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar +from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar import numpy as np import torch +from vllm.attention.layer import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils import cdiv if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionImpl from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch @@ -98,39 +101,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): return False -def validate_kv_sharing_target(current_layer_name, target_layer_name, - static_forward_context): - error_msg = (f"Specified KV sharing target layer for {current_layer_name} " - f"is not valid: target layer {target_layer_name} ") - - if current_layer_name == target_layer_name: - raise ValueError(error_msg + - "cannot be the same as the current layer.") - - if target_layer_name not in static_forward_context: - from vllm.model_executor.models.utils import extract_layer_index - - # If target layer name is not in the static fwd context, it means either - # a) the target layer does not come BEFORE the current layer, or - # b) the target layer is not an Attention layer that exists in the model - current_layer_idx = extract_layer_index(current_layer_name) - target_layer_idx = extract_layer_index(target_layer_name) - if current_layer_idx <= target_layer_idx: - raise ValueError(error_msg + "must come before the current layer.") - else: - raise ValueError(error_msg + - "is not a valid Attention layer in the model.") - - # Currently KV sharing is only supported between layers of the same type - target_layer_attn_type = static_forward_context[ - target_layer_name].attn_type - expected = static_forward_context[current_layer_name].attn_type - if target_layer_attn_type != expected: - raise ValueError( - error_msg + - f"must be the same type as the current layer ({expected}).") - - @functools.lru_cache def get_kv_cache_layout(): # Override with format specified by the user. 
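
The helper removed in the hunk above was relocated to the new vllm/attention/utils/kv_sharing_utils.py (added earlier in this diff) and is now imported from there by vllm/attention/layer.py. A minimal call sketch, using hypothetical layer names and a stand-in forward context (illustrative only, not part of the patch):

from types import SimpleNamespace

from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target

# Stand-in forward context: real callers pass the Attention layers registered
# in vLLM's static forward context; only `.attn_type` is inspected here.
ctx = {
    "model.layers.4.self_attn.attn": SimpleNamespace(attn_type="decoder"),
    "model.layers.5.self_attn.attn": SimpleNamespace(attn_type="decoder"),
}

# Passes: layer 5 targets the earlier, same-type layer 4. A self-target, a
# later-layer target, or a mismatched attention type raises ValueError.
validate_kv_sharing_target("model.layers.5.self_attn.attn",
                           "model.layers.4.self_attn.attn", ctx)
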
@@ -144,6 +114,71 @@ def get_kv_cache_layout(): return cache_layout +@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig, + cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. + """ + + layers = get_layers_from_vllm_config(vllm_config, Attention) + per_layer_params: dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + impl = layer.impl + assert isinstance(impl, cls_) + + # Infer hyperparameters from the attention layer + window_size = getattr(impl, "sliding_window", None) + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = getattr(impl, "logits_soft_cap", None) + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." + + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + + # # Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into # local attention blocks, where each block is passed to the attention kernel
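
Usage sketch for the relocated hyperparameter helpers above (illustrative only; resolve_global_params and its arguments are placeholder names): a metadata builder resolves the model-wide (window_left, logits_soft_cap, sm_scale) values once by passing its own impl class, mirroring the FlashInferImpl and MLACommonImpl call sites earlier in this diff.

from vllm.v1.attention.backends.utils import (get_per_layer_parameters,
                                              infer_global_hyperparameters)


def resolve_global_params(vllm_config, impl_cls):
    # Scan every Attention layer (asserting each uses `impl_cls`) and collapse
    # their (window_left, logits_soft_cap, sm_scale) into a single shared
    # PerLayerParameters, failing loudly if any layer disagrees.
    return infer_global_hyperparameters(
        get_per_layer_parameters(vllm_config, impl_cls))


# e.g. inside a metadata builder's __init__ (as done for FlashInferImpl above):
#   self.global_hyperparameters = resolve_global_params(self.vllm_config,
#                                                       FlashInferImpl)
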