diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 237802afccde9..fc1d61962ae17 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -418,35 +418,39 @@ class MultiHeadAttention(nn.Module):
 
 
 def wait_for_kv_layer_from_connector(layer_name: str):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
+    print("hi --- wait_for_kv_layer_from_connector")
+    pass
+    # if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+    #     return
 
-    connector = get_kv_transfer_group()
+    # connector = get_kv_transfer_group()
 
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    assert isinstance(attn_metadata, dict)
-    connector.wait_for_layer_load(layer_name)
+    # forward_context: ForwardContext = get_forward_context()
+    # attn_metadata = forward_context.attn_metadata
+    # if attn_metadata is None:
+    #     return
+    # assert isinstance(attn_metadata, dict)
+    # connector.wait_for_layer_load(layer_name)
 
 
 def maybe_save_kv_layer_to_connector(
     layer_name: str,
     kv_cache_layer: List[torch.Tensor],
 ):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
+    print("hi --- maybe_save_kv_layer_to_connector")
+    pass
+    # if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+    #     return
 
-    connector = get_kv_transfer_group()
+    # connector = get_kv_transfer_group()
 
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    assert isinstance(attn_metadata, dict)
-    connector.save_kv_layer(layer_name, kv_cache_layer,
-                            attn_metadata[layer_name])
+    # forward_context: ForwardContext = get_forward_context()
+    # attn_metadata = forward_context.attn_metadata
+    # if attn_metadata is None:
+    #     return
+    # assert isinstance(attn_metadata, dict)
+    # connector.save_kv_layer(layer_name, kv_cache_layer,
+    #                         attn_metadata[layer_name])
 
 
 def unified_attention(
@@ -497,7 +501,7 @@ def unified_attention_with_output(
     output_scale: Optional[torch.Tensor] = None,
     output_block_scale: Optional[torch.Tensor] = None,
 ) -> None:
-    wait_for_kv_layer_from_connector(layer_name)
+    # wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     if isinstance(attn_metadata, dict):
@@ -514,7 +518,7 @@ def unified_attention_with_output(
                       output_scale=output_scale,
                       output_block_scale=output_block_scale)
 
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    # maybe_save_kv_layer_to_connector(layer_name, kv_cache)
 
 
 def unified_attention_with_output_fake(