mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:16:06 +08:00
[Misc] Refactor Attention kv transfer methods into decorator (#27816)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
parent bc5bd45c7d
commit 728a9eb70e
vllm/attention/layer.py
@@ -15,14 +15,10 @@ from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
+from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.config.multimodal import MultiModalConfig
 from vllm.config.vllm import VllmConfig
-from vllm.distributed.kv_transfer import (
-    get_kv_transfer_group,
-    has_kv_transfer_group,
-    is_v1_kv_transfer_group,
-)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -842,41 +838,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         )


-def wait_for_kv_layer_from_connector(layer_name: str):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-
-    connector = get_kv_transfer_group()
-    if not connector.has_connector_metadata():
-        return
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    assert isinstance(attn_metadata, dict)
-    connector.wait_for_layer_load(layer_name)
-
-
-def maybe_save_kv_layer_to_connector(
-    layer_name: str,
-    kv_cache_layer: list[torch.Tensor],
-):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-
-    connector = get_kv_transfer_group()
-    if not connector.has_connector_metadata():
-        return
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    assert isinstance(attn_metadata, dict)
-    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
-
-
 def maybe_calc_kv_scales(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -911,23 +872,46 @@ direct_register_custom_op(
 )


+def get_attention_context(
+    layer_name: str,
+) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
+    """Extract attention context for a given layer.
+
+    This helper function extracts the attention metadata, attention layer
+    instance, and KV cache tensor for a specific layer.
+
+    Args:
+        layer_name: The name/identifier of the attention layer.
+
+    Returns:
+        A tuple containing:
+        - attn_metadata: Attention metadata for this specific layer, or None
+          if no metadata available
+        - attn_layer: The attention layer instance (Attention or MLAAttention)
+        - kv_cache: The KV cache tensor for the current virtual engine
+
+    Note: attn_metadata may be None, but attn_layer and kv_cache are always
+    extracted from the forward context.
+    """
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
+    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
+    return attn_metadata, attn_layer, kv_cache
+
+
+@maybe_transfer_kv_layer
 def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
-    wait_for_kv_layer_from_connector(layer_name)
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if isinstance(attn_metadata, dict):
-        attn_metadata = attn_metadata[layer_name]
-    self = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    attn_metadata, self, kv_cache = get_attention_context(layer_name)
     output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
-
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
     return output
@@ -947,6 +931,7 @@ direct_register_custom_op(
 )


+@maybe_transfer_kv_layer
 def unified_attention_with_output(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -956,13 +941,7 @@ def unified_attention_with_output(
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
 ) -> None:
-    wait_for_kv_layer_from_connector(layer_name)
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if isinstance(attn_metadata, dict):
-        attn_metadata = attn_metadata[layer_name]
-    self = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    attn_metadata, self, kv_cache = get_attention_context(layer_name)
     self.impl.forward(
         self,
         query,
@@ -975,8 +954,6 @@ def unified_attention_with_output(
         output_block_scale=output_block_scale,
     )

-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
-

 def unified_attention_with_output_fake(
     query: torch.Tensor,
@@ -998,23 +975,16 @@ direct_register_custom_op(
 )


+@maybe_transfer_kv_layer
 def unified_mla_attention(
     q: torch.Tensor,
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
-    wait_for_kv_layer_from_connector(layer_name)
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if isinstance(attn_metadata, dict):
-        attn_metadata = attn_metadata[layer_name]
-    self: MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    attn_metadata, self, kv_cache = get_attention_context(layer_name)
     output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
-
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
     return output
@@ -1036,6 +1006,7 @@ direct_register_custom_op(
 )


+@maybe_transfer_kv_layer
 def unified_mla_attention_with_output(
     q: torch.Tensor,
     kv_c_normed: torch.Tensor,
@@ -1045,13 +1016,7 @@ def unified_mla_attention_with_output(
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
 ) -> None:
-    wait_for_kv_layer_from_connector(layer_name)
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if isinstance(attn_metadata, dict):
-        attn_metadata = attn_metadata[layer_name]
-    self: MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    attn_metadata, self, kv_cache = get_attention_context(layer_name)
     self.impl.forward(
         self,
         q,
@@ -1064,8 +1029,6 @@ def unified_mla_attention_with_output(
         output_block_scale=output_block_scale,
     )

-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
-

 def unified_mla_attention_with_output_fake(
     q: torch.Tensor,
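The get_attention_context helper added above centralizes the forward-context lookup that every unified_* entry point previously repeated inline. Below is a minimal, runnable sketch of that lookup under toy assumptions: ToyLayer and ToyForwardContext are illustrative stand-ins, not vLLM internals, and the real helper reads the context via get_forward_context() rather than taking it as a parameter.

    # Illustrative sketch only -- toy stand-ins for the forward context,
    # layer, and cache objects; not vLLM internals.
    from dataclasses import dataclass
    from typing import Any

    import torch


    @dataclass
    class ToyLayer:
        # One KV-cache tensor per virtual engine, mirroring
        # `attn_layer.kv_cache[forward_context.virtual_engine]` in the diff.
        kv_cache: list[torch.Tensor]


    @dataclass
    class ToyForwardContext:
        # attn_metadata may be None, a single object, or a per-layer dict.
        attn_metadata: Any
        no_compile_layers: dict[str, ToyLayer]
        virtual_engine: int = 0


    def get_attention_context(ctx: ToyForwardContext, layer_name: str):
        attn_metadata = ctx.attn_metadata
        if isinstance(attn_metadata, dict):
            # Per-layer dict: select this layer's metadata.
            attn_metadata = attn_metadata[layer_name]
        attn_layer = ctx.no_compile_layers[layer_name]
        kv_cache = attn_layer.kv_cache[ctx.virtual_engine]
        return attn_metadata, attn_layer, kv_cache


    ctx = ToyForwardContext(
        attn_metadata={"layers.0.attn": "meta-0"},
        no_compile_layers={"layers.0.attn": ToyLayer([torch.zeros(2, 4)])},
    )
    meta, layer, cache = get_attention_context(ctx, "layers.0.attn")
    assert meta == "meta-0" and cache.shape == (2, 4)

The isinstance check mirrors the real helper: attn_metadata is indexed by layer_name only when it is a per-layer dict; otherwise it is passed through unchanged (including None).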
vllm/attention/utils/kv_transfer_utils.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import inspect
+from collections.abc import Callable
+from functools import wraps
+
+from vllm.distributed.kv_transfer import (
+    get_kv_transfer_group,
+    has_kv_transfer_group,
+    is_v1_kv_transfer_group,
+)
+
+
+def maybe_transfer_kv_layer(func: Callable) -> Callable:
+    """Decorator that handles KV layer transfer before and after execution of
+    an attention layer, if enabled. Otherwise, the wrapper is a no-op.
+
+    On entry: waits for the KV layer from the connector.
+    On exit: saves the KV layer to the connector.
+    """
+    # Import at runtime to avoid a circular dependency.
+    from vllm.attention.layer import get_attention_context
+
+    # Inspect the signature ONCE, when the decorator is applied.
+    sig = inspect.signature(func)
+    param_names = list(sig.parameters.keys())
+
+    # Find the index of the 'layer_name' parameter.
+    try:
+        layer_name_index = param_names.index("layer_name")
+    except ValueError as e:
+        raise TypeError(
+            f"Function {func.__name__} must have a 'layer_name' parameter"
+        ) from e
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+            return func(*args, **kwargs)
+
+        layer_name: str = args[layer_name_index]
+
+        # Extract attention context (layer-specific metadata, layer, kv_cache)
+        attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
+        connector = get_kv_transfer_group()
+        if attn_metadata is None or not connector.has_connector_metadata():
+            return func(*args, **kwargs)
+
+        # Wait for the KV layer on entry.
+        connector.wait_for_layer_load(layer_name)
+
+        # Execute the wrapped function.
+        result = func(*args, **kwargs)
+
+        # Save the KV cache layer on exit.
+        connector.save_kv_layer(layer_name, kv_cache, attn_metadata)
+
+        return result
+
+    return wrapper
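Stripped of the vLLM plumbing, the decorator's core technique is to resolve the positional index of layer_name once, at decoration time, then bracket each call with the wait/save hooks. Below is a minimal, runnable sketch under toy assumptions: ToyConnector and _CONNECTOR stand in for the transfer-group machinery, and the no-op and missing-metadata fast paths are omitted.

    # Illustrative sketch only -- ToyConnector is a stand-in, not the vLLM
    # connector API.
    import inspect
    from collections.abc import Callable
    from functools import wraps


    class ToyConnector:
        def wait_for_layer_load(self, layer_name: str) -> None:
            print(f"load <- {layer_name}")

        def save_kv_layer(self, layer_name, kv_cache, attn_metadata) -> None:
            print(f"save -> {layer_name}")


    _CONNECTOR = ToyConnector()


    def transfer_kv_layer(func: Callable) -> Callable:
        # Resolve the positional index of 'layer_name' once, not per call.
        params = list(inspect.signature(func).parameters)
        try:
            layer_name_index = params.index("layer_name")
        except ValueError as e:
            raise TypeError(
                f"Function {func.__name__} must have a 'layer_name' parameter"
            ) from e

        @wraps(func)
        def wrapper(*args, **kwargs):
            layer_name = args[layer_name_index]
            _CONNECTOR.wait_for_layer_load(layer_name)  # on entry
            result = func(*args, **kwargs)
            _CONNECTOR.save_kv_layer(layer_name, None, None)  # on exit
            return result

        return wrapper


    @transfer_kv_layer
    def unified_attention(query, key, value, layer_name: str):
        return query  # stands in for self.impl.forward(...)


    unified_attention(1.0, 2.0, 3.0, "layers.0.attn")
    # prints: load <- layers.0.attn
    #         save -> layers.0.attn

One design consequence worth noting: because the wrapper reads args[layer_name_index], the wrapped functions are expected to receive layer_name positionally; that holds for the registered custom-op entry points this decorator is applied to in the diff above.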