From 97d6c30cc965e70579bfdad27e7514592752096e Mon Sep 17 00:00:00 2001
From: WeiQing Chen <40507679+david6666666@users.noreply.github.com>
Date: Sat, 26 Jul 2025 22:07:40 +0800
Subject: [PATCH] [BugFix] Fix shared storage connector load kv only load
 attention layer (#21428)

Signed-off-by: David Chen <530634352@qq.com>
---
 .../kv_connector/v1/shared_storage_connector.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
index 3c574d0655717..048748e6b8ecb 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -156,8 +156,16 @@ class SharedStorageConnector(KVConnectorBase_V1):
             logger.info("Inject KV cache of %d tokens to the paged memory",
                         len(request.slot_mapping))
             for layer_name in forward_context.no_compile_layers:
-                attn_layer = forward_context.no_compile_layers[layer_name]
-                kv_cache_layer = attn_layer.kv_cache[\
+                layer = forward_context.no_compile_layers[layer_name]
+
+                # Only process layers that have kv_cache
+                # attribute (attention layers) Skip non-attention
+                # layers like FusedMoE/MLP etc.
+                kv_cache_attr = getattr(layer, 'kv_cache', None)
+                if kv_cache_attr is None:
+                    continue
+
+                kv_cache_layer = kv_cache_attr[ \
                     forward_context.virtual_engine]
                 filename = self._generate_filename_debug(
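
Editor's note: below is a minimal, self-contained sketch of the guard this
patch introduces. The AttentionLayer and FusedMoELayer classes, the
no_compile_layers dict contents, and the virtual_engine value are all
hypothetical stand-ins for the real vLLM objects. The point it illustrates is
the one the patch relies on: forward_context.no_compile_layers can contain
non-attention layers (e.g. FusedMoE/MLP) that have no kv_cache attribute, so
dereferencing kv_cache unconditionally, as the old code did, raises
AttributeError; the getattr guard skips those layers instead.

    # Hypothetical stand-ins for the layer modules that vLLM stores in
    # forward_context.no_compile_layers. Only the attention stand-in
    # carries a kv_cache attribute, mirroring the real layer types.
    class AttentionLayer:
        def __init__(self):
            # One KV-cache entry per virtual engine, indexed by position.
            self.kv_cache = ["kv for engine 0", "kv for engine 1"]

    class FusedMoELayer:
        pass  # No kv_cache attribute: must be skipped when loading KV.

    no_compile_layers = {
        "model.layers.0.self_attn.attn": AttentionLayer(),
        "model.layers.0.mlp": FusedMoELayer(),
    }
    virtual_engine = 0

    for layer_name, layer in no_compile_layers.items():
        # The guard from the patch: attention layers expose kv_cache,
        # everything else yields None and is skipped.
        kv_cache_attr = getattr(layer, "kv_cache", None)
        if kv_cache_attr is None:
            continue
        kv_cache_layer = kv_cache_attr[virtual_engine]
        print(layer_name, "->", kv_cache_layer)

Running this prints only the attention layer's entry; the MoE layer is skipped
rather than crashing the loop. Using getattr with a None default rather than
hasattr keeps the lookup to a single attribute access and gives the later code
the attribute value directly.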