diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index ec705126c710d..487bba76babf1 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -15,14 +15,10 @@ from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
+from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.config.multimodal import MultiModalConfig
 from vllm.config.vllm import VllmConfig
-from vllm.distributed.kv_transfer import (
-    get_kv_transfer_group,
-    has_kv_transfer_group,
-    is_v1_kv_transfer_group,
-)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -842,41 +838,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         )
 
 
-def wait_for_kv_layer_from_connector(layer_name: str):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-
-    connector = get_kv_transfer_group()
-    if not connector.has_connector_metadata():
-        return
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    assert isinstance(attn_metadata, dict)
-    connector.wait_for_layer_load(layer_name)
-
-
-def maybe_save_kv_layer_to_connector(
-    layer_name: str,
-    kv_cache_layer: list[torch.Tensor],
-):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-
-    connector = get_kv_transfer_group()
-    if not connector.has_connector_metadata():
-        return
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    assert isinstance(attn_metadata, dict)
-    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
-
-
 def maybe_calc_kv_scales(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -911,23 +872,46 @@ direct_register_custom_op(
 )
 
 
+def get_attention_context(
+    layer_name: str,
+) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
+    """Extract attention context for a given layer.
+
+    This helper function extracts the attention metadata, attention layer
+    instance, and KV cache tensor for a specific layer.
+
+    Args:
+        layer_name: The name/identifier of the attention layer.
+
+    Returns:
+        A tuple containing:
+        - attn_metadata: Attention metadata for this specific layer, or None
+          if no metadata is available.
+        - attn_layer: The attention layer instance (Attention or MLAAttention).
+        - kv_cache: The KV cache tensor for the current virtual engine.
+
+    Note: attn_metadata may be None, but attn_layer and kv_cache are always
+    extracted from the forward context.
+ """ + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name] + kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] + return attn_metadata, attn_layer, kv_cache + + +@maybe_transfer_kv_layer def unified_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, layer_name: str, ) -> torch.Tensor: - wait_for_kv_layer_from_connector(layer_name) - - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] + attn_metadata, self, kv_cache = get_attention_context(layer_name) output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) return output @@ -947,6 +931,7 @@ direct_register_custom_op( ) +@maybe_transfer_kv_layer def unified_attention_with_output( query: torch.Tensor, key: torch.Tensor, @@ -956,13 +941,7 @@ def unified_attention_with_output( output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: - wait_for_kv_layer_from_connector(layer_name) - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] + attn_metadata, self, kv_cache = get_attention_context(layer_name) self.impl.forward( self, query, @@ -975,8 +954,6 @@ def unified_attention_with_output( output_block_scale=output_block_scale, ) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) - def unified_attention_with_output_fake( query: torch.Tensor, @@ -998,23 +975,16 @@ direct_register_custom_op( ) +@maybe_transfer_kv_layer def unified_mla_attention( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, ) -> torch.Tensor: - wait_for_kv_layer_from_connector(layer_name) - - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] - self: MLAAttention = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] + attn_metadata, self, kv_cache = get_attention_context(layer_name) output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) return output @@ -1036,6 +1006,7 @@ direct_register_custom_op( ) +@maybe_transfer_kv_layer def unified_mla_attention_with_output( q: torch.Tensor, kv_c_normed: torch.Tensor, @@ -1045,13 +1016,7 @@ def unified_mla_attention_with_output( output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: - wait_for_kv_layer_from_connector(layer_name) - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] - self: MLAAttention = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] 
+    attn_metadata, self, kv_cache = get_attention_context(layer_name)
     self.impl.forward(
         self,
         q,
@@ -1064,8 +1029,6 @@ def unified_mla_attention_with_output(
         output_block_scale=output_block_scale,
     )
 
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
-
 
 def unified_mla_attention_with_output_fake(
     q: torch.Tensor,
diff --git a/vllm/attention/utils/kv_transfer_utils.py b/vllm/attention/utils/kv_transfer_utils.py
new file mode 100644
index 0000000000000..210be55feb2fa
--- /dev/null
+++ b/vllm/attention/utils/kv_transfer_utils.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import inspect
+from collections.abc import Callable
+from functools import wraps
+
+from vllm.distributed.kv_transfer import (
+    get_kv_transfer_group,
+    has_kv_transfer_group,
+    is_v1_kv_transfer_group,
+)
+
+
+def maybe_transfer_kv_layer(func: Callable) -> Callable:
+    """Decorator that handles KV layer transfer before and after execution of
+    an attention layer, if enabled. Otherwise, the wrapper is a no-op.
+
+    On entry: waits for the KV layer from the connector.
+    On exit: saves the KV layer to the connector.
+    """
+    # Import at runtime to avoid circular dependency
+    from vllm.attention.layer import get_attention_context
+
+    # Inspect the signature ONCE when the decorator is applied.
+    sig = inspect.signature(func)
+    param_names = list(sig.parameters.keys())
+
+    # Find the index of 'layer_name' parameter.
+    try:
+        layer_name_index = param_names.index("layer_name")
+    except ValueError as e:
+        raise TypeError(
+            f"Function {func.__name__} must have a 'layer_name' parameter"
+        ) from e
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+            return func(*args, **kwargs)
+
+        layer_name: str = args[layer_name_index]
+
+        # Extract attention context (layer-specific metadata, layer, and kv_cache)
+        attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
+        connector = get_kv_transfer_group()
+        if attn_metadata is None or not connector.has_connector_metadata():
+            return func(*args, **kwargs)
+
+        # Wait for KV layer on entry
+        connector.wait_for_layer_load(layer_name)
+
+        # Execute the function
+        result = func(*args, **kwargs)
+
+        # Save KV cache layer on exit
+        connector.save_kv_layer(layer_name, kv_cache, attn_metadata)
+
+        return result
+
+    return wrapper
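
For readers skimming the patch, the decorator's control flow can be summarized with a small self-contained sketch. Everything below is illustrative only (a toy connector and a toy attention function, not vLLM APIs): it shows the decoration-time lookup of the 'layer_name' parameter index and the wait-before / save-after wrapping that maybe_transfer_kv_layer applies to the unified attention ops.

# Illustrative sketch only; toy stand-ins, not vLLM code.
import inspect
from functools import wraps


class ToyConnector:
    def wait_for_layer_load(self, layer_name):
        print(f"wait_for_layer_load({layer_name})")

    def save_kv_layer(self, layer_name, kv_cache, attn_metadata):
        print(f"save_kv_layer({layer_name})")


connector = ToyConnector()


def transfer_kv_layer(func):
    # Resolve the positional index of 'layer_name' once, at decoration time.
    layer_name_index = list(inspect.signature(func).parameters).index("layer_name")

    @wraps(func)
    def wrapper(*args, **kwargs):
        layer_name = args[layer_name_index]
        connector.wait_for_layer_load(layer_name)         # on entry
        result = func(*args, **kwargs)                    # run the wrapped op
        connector.save_kv_layer(layer_name, None, None)   # on exit
        return result

    return wrapper


@transfer_kv_layer
def toy_attention(query, key, value, layer_name):
    return query  # stand-in for the real attention computation


toy_attention("q", "k", "v", "model.layers.0.self_attn")

Running the sketch prints the wait/save hooks around the wrapped call, which is the same ordering the real decorator enforces around each unified attention op.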