From 26869256308a3bfbfeedc86bddb5bc92ab836a91 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 8 Sep 2025 21:06:29 +0000 Subject: [PATCH] updated Signed-off-by: Robert Shaw --- vllm/attention/layer.py | 46 ++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 237802afccde9..fc1d61962ae17 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -418,35 +418,39 @@ class MultiHeadAttention(nn.Module): def wait_for_kv_layer_from_connector(layer_name: str): - if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): - return + print("hi --- wait_for_kv_layer_from_connector") + pass + # if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + # return - connector = get_kv_transfer_group() + # connector = get_kv_transfer_group() - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if attn_metadata is None: - return - assert isinstance(attn_metadata, dict) - connector.wait_for_layer_load(layer_name) + # forward_context: ForwardContext = get_forward_context() + # attn_metadata = forward_context.attn_metadata + # if attn_metadata is None: + # return + # assert isinstance(attn_metadata, dict) + # connector.wait_for_layer_load(layer_name) def maybe_save_kv_layer_to_connector( layer_name: str, kv_cache_layer: List[torch.Tensor], ): - if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): - return + print("hi --- maybe_save_kv_layer_to_connector") + pass + # if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + # return - connector = get_kv_transfer_group() + # connector = get_kv_transfer_group() - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if attn_metadata is None: - return - assert isinstance(attn_metadata, dict) - connector.save_kv_layer(layer_name, kv_cache_layer, - attn_metadata[layer_name]) + # forward_context: ForwardContext = get_forward_context() + # attn_metadata = forward_context.attn_metadata + # if attn_metadata is None: + # return + # assert isinstance(attn_metadata, dict) + # connector.save_kv_layer(layer_name, kv_cache_layer, + # attn_metadata[layer_name]) def unified_attention( @@ -497,7 +501,7 @@ def unified_attention_with_output( output_scale: Optional[torch.Tensor] = None, output_block_scale: Optional[torch.Tensor] = None, ) -> None: - wait_for_kv_layer_from_connector(layer_name) + # wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): @@ -514,7 +518,7 @@ def unified_attention_with_output( output_scale=output_scale, output_block_scale=output_block_scale) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) + # maybe_save_kv_layer_to_connector(layer_name, kv_cache) def unified_attention_with_output_fake(