diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py
index 8b9411975e153..820f5d42d167c 100644
--- a/tests/utils_/test_utils.py
+++ b/tests/utils_/test_utils.py
@@ -24,7 +24,6 @@ from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
 from vllm.utils import (
     FlexibleArgumentParser,
     MemorySnapshot,
-    bind_kv_cache,
     common_broadcastable_dtype,
     current_stream,
     get_open_port,
@@ -343,87 +342,6 @@ def test_memory_profiling():
     lib.cudaFree(handle2)
 
 
-def test_bind_kv_cache():
-    from vllm.attention import Attention
-
-    ctx = {
-        "layers.0.self_attn": Attention(32, 128, 0.1),
-        "layers.1.self_attn": Attention(32, 128, 0.1),
-        "layers.2.self_attn": Attention(32, 128, 0.1),
-        "layers.3.self_attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    bind_kv_cache(ctx, [kv_cache])
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[2]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[3]
-
-
-def test_bind_kv_cache_kv_sharing():
-    from vllm.attention import Attention
-
-    ctx = {
-        "layers.0.self_attn": Attention(32, 128, 0.1),
-        "layers.1.self_attn": Attention(32, 128, 0.1),
-        "layers.2.self_attn": Attention(32, 128, 0.1),
-        "layers.3.self_attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    shared_kv_cache_layers = {
-        "layers.2.self_attn": "layers.1.self_attn",
-        "layers.3.self_attn": "layers.0.self_attn",
-    }
-    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[0]
-
-
-def test_bind_kv_cache_non_attention():
-    from vllm.attention import Attention
-
-    # example from Jamba PP=2
-    ctx = {
-        "model.layers.20.attn": Attention(32, 128, 0.1),
-        "model.layers.28.attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    bind_kv_cache(ctx, [kv_cache])
-    assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache[1]
-
-
-def test_bind_kv_cache_pp():
-    with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
-        # this test runs with 1 GPU, but we simulate 2 GPUs
-        cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
-        with set_current_vllm_config(cfg):
-            from vllm.attention import Attention
-
-            ctx = {
-                "layers.0.self_attn": Attention(32, 128, 0.1),
-            }
-            kv_cache = [[torch.zeros((1,))], [torch.zeros((1,))]]
-            bind_kv_cache(ctx, kv_cache)
-            assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0][0]
-            assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]
-
-
 @pytest.mark.parametrize(
     ("src_dtype", "tgt_dtype", "expected_result"),
     [
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 869e80a1af88c..293fa4d62e399 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -382,7 +382,6 @@ class TestNixlHandshake:
         dummy_ctx = ForwardContext(
             no_compile_layers={},
             attn_metadata={},
-            virtual_engine=0,
         )
         _before_load = time.perf_counter()
         connector.start_load_kv(dummy_ctx)
@@ -450,7 +449,6 @@ class TestNixlHandshake:
         dummy_ctx = ForwardContext(
             no_compile_layers={},
             attn_metadata={},
-            virtual_engine=0,
         )
         _before_load = time.perf_counter()
         connector.start_load_kv(dummy_ctx)
@@ -506,7 +504,6 @@ class TestNixlHandshake:
         dummy_ctx = ForwardContext(
             no_compile_layers={},
             attn_metadata={},
-            virtual_engine=0,
         )
         _before_load = time.perf_counter()
         connector.start_load_kv(dummy_ctx)
@@ -666,7 +663,6 @@ def test_kv_connector_stats(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)
 
@@ -1241,7 +1237,6 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)
 
@@ -1344,7 +1339,6 @@ def test_handshake_failure_returns_finished(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)
 
@@ -1393,7 +1387,6 @@ def test_transfer_setup_failure_returns_finished(dist_init):
     dummy_ctx = ForwardContext(
         no_compile_layers={},
         attn_metadata={},
-        virtual_engine=0,
     )
     connector.start_load_kv(dummy_ctx)
 
diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py
index 46a5c097094eb..b87a91eb857ec 100644
--- a/tests/v1/kv_connector/unit/test_offloading_connector.py
+++ b/tests/v1/kv_connector/unit/test_offloading_connector.py
@@ -179,7 +179,7 @@ class RequestRunner:
         self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)
 
         self._dummy_ctx: ForwardContext = ForwardContext(
-            no_compile_layers={}, attn_metadata={}, virtual_engine=0
+            no_compile_layers={}, attn_metadata={}
         )
 
     def new_request(self, token_ids: list[int]):
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 4b591f07ca2d4..ee24f52595af9 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -272,14 +272,9 @@ class Attention(nn.Module, AttentionLayerBase):
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
         # use a placeholder kv cache tensor during init, which will be replaced
-        # by bind_kv_cache
-        # this variable will not be accessed if use_direct_call is True
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
-        ]
+        # by bind_kv_cache. This variable will not be accessed if
+        # use_direct_call is True.
+        self.kv_cache = torch.tensor([])
 
         try:
             self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
@@ -361,9 +356,9 @@ class Attention(nn.Module, AttentionLayerBase):
                 attn_metadata = forward_context.attn_metadata
                 if isinstance(attn_metadata, dict):
                     attn_metadata = attn_metadata[self.layer_name]
-                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+
                 self.impl.forward(
-                    self, query, key, value, self_kv_cache, attn_metadata, output=output
+                    self, query, key, value, self.kv_cache, attn_metadata, output=output
                 )
             else:
                 torch.ops.vllm.unified_attention_with_output(
@@ -376,9 +371,9 @@ class Attention(nn.Module, AttentionLayerBase):
                 attn_metadata = forward_context.attn_metadata
                 if isinstance(attn_metadata, dict):
                     attn_metadata = attn_metadata[self.layer_name]
-                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+
                 return self.impl.forward(
-                    self, query, key, value, self_kv_cache, attn_metadata
+                    self, query, key, value, self.kv_cache, attn_metadata
                 )
             else:
                 return torch.ops.vllm.unified_attention(
@@ -644,12 +639,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
 
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
-        ]
+        self.kv_cache = torch.tensor([])
 
         # Align with Attention's scale attributes for MLA backends.
 
@@ -688,7 +678,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             attn_metadata = forward_context.attn_metadata
             if isinstance(attn_metadata, dict):
                 attn_metadata = attn_metadata[self.layer_name]
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
 
             # Mirror Attention.forward scale calculation path
             if self.calculate_kv_scales and getattr(
@@ -703,14 +692,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     q,
                     kv_c_normed,
                     k_pe,
-                    self_kv_cache,
+                    self.kv_cache,
                     attn_metadata,
                     output=output,
                 )
                 return output
             else:
                 return self.impl.forward(
-                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
+                    self, q, kv_c_normed, k_pe, self.kv_cache, attn_metadata
                 )
         else:
             if self.attn_backend.accept_output_buffer:
@@ -785,7 +774,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
 
 def maybe_save_kv_layer_to_connector(
     layer_name: str,
-    kv_cache_layer: list[torch.Tensor],
+    kv_cache_layer: torch.Tensor,
 ):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
         return
@@ -851,10 +840,9 @@ def unified_attention(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
-    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
+    output = self.impl.forward(self, query, key, value, self.kv_cache, attn_metadata)
 
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
     return output
 
 
@@ -889,20 +877,19 @@ def unified_attention_with_output(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(
         self,
         query,
         key,
         value,
-        kv_cache,
+        self.kv_cache,
         attn_metadata,
         output=output,
         output_scale=output_scale,
         output_block_scale=output_block_scale,
     )
 
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
 
 
 def unified_attention_with_output_fake(
@@ -938,10 +925,9 @@ def unified_mla_attention(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self: MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
-    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
+    output = self.impl.forward(self, q, kv_c_normed, k_pe, self.kv_cache, attn_metadata)
 
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
     return output
 
 
@@ -978,20 +964,19 @@ def unified_mla_attention_with_output(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     self: MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(
         self,
         q,
         kv_c_normed,
         k_pe,
-        kv_cache,
+        self.kv_cache,
        attn_metadata,
         output=output,
         output_scale=output_scale,
         output_block_scale=output_block_scale,
     )
 
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    maybe_save_kv_layer_to_connector(layer_name, self.kv_cache)
 
 
 def unified_mla_attention_with_output_fake(
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
index e47cde2614fc2..76be0298452dd 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@@ -200,12 +200,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 # Only process layers that have kv_cache
                 # attribute (attention layers) Skip non-attention
                 # layers like FusedMoE
-                kv_cache = getattr(layer, "kv_cache", None)
-                if kv_cache is None:
+                layer = getattr(layer, "kv_cache", None)
+                if layer is None:
                     continue
 
-                layer = kv_cache[forward_context.virtual_engine]
-
                 kv_cache = self.p2p_nccl_engine.recv_tensor(
                     request.request_id + "#" + layer_name, remote_address
                 )
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 484de15040c21..899139e3b57c0 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -37,7 +37,7 @@ class BatchDescriptor(NamedTuple):
     num_tokens: int
     uniform_decode: bool = False
     """
-    False can also be used for an uniform decode batch to dispatch to the 
+    False can also be used for a uniform decode batch to dispatch to the
     cudagraph supporting non-uniform batches.
     """
 
@@ -179,8 +179,8 @@ class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
     no_compile_layers: dict[str, Any]
     """
-    Type AttentionMetadata for v0, 
-    Type Dict[str, AttentionMetadata] for v1, map from layer_name of each 
+    Type AttentionMetadata for v0,
+    Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
     attention layer to its attention metadata
     Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
     for each microbatch.
@@ -191,8 +191,6 @@ class ForwardContext:
         dict[str, "AttentionMetadata"],
         list[dict[str, "AttentionMetadata"]],
     ]
-    # TODO: remove after making all virtual_engines share the same kv cache
-    virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
     dp_metadata: DPMetadata | None = None
     # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
@@ -223,7 +221,6 @@ def get_forward_context() -> ForwardContext:
 def create_forward_context(
     attn_metadata: Any,
     vllm_config: VllmConfig,
-    virtual_engine: int = 0,
     dp_metadata: DPMetadata | None = None,
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     batch_descriptor: BatchDescriptor | None = None,
@@ -231,7 +228,6 @@ def create_forward_context(
 ):
     return ForwardContext(
         no_compile_layers=vllm_config.compilation_config.static_forward_context,
-        virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
         cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -259,7 +255,6 @@ def override_forward_context(forward_context: ForwardContext | None):
 def set_forward_context(
     attn_metadata: Any,
     vllm_config: VllmConfig,
-    virtual_engine: int = 0,
     num_tokens: int | None = None,
     num_tokens_across_dp: torch.Tensor | None = None,
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
@@ -305,7 +300,6 @@ def set_forward_context(
     forward_context = create_forward_context(
         attn_metadata,
         vllm_config,
-        virtual_engine,
         dp_metadata,
         cudagraph_runtime_mode,
         batch_descriptor,
diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
index ce8f50bb27b82..bb1b83c4e9516 100644
--- a/vllm/model_executor/layers/mamba/linear_attn.py
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -328,7 +328,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
         if attn_metadata is not None:
-            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+            kv_cache = self.kv_cache[0]
             state_indices_tensor = attn_metadata.state_indices_tensor
             num_prefills = getattr(attn_metadata, "num_prefills", 0)
 
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 8f7317556f776..d5c7660855665 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -248,9 +248,8 @@ class MambaMixer(MambaBase, CustomOp):
             assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
             query_start_loc = mamba1_metadata.query_start_loc
             state_indices_tensor = mamba1_metadata.state_indices_tensor
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-            conv_state = self_kv_cache[0].transpose(-1, -2)
-            ssm_state = self_kv_cache[1]
+            conv_state = self.kv_cache[0].transpose(-1, -2)
+            ssm_state = self.kv_cache[1]
             has_initial_states = mamba1_metadata.has_initial_states
             num_padded_decodes = mamba1_metadata.num_padded_decodes
 
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index b0ee327a82347..caa9fb9864d42 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -511,10 +511,9 @@ class MambaMixer2(MambaBase, CustomOp):
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
             assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             # conv_state = (..., dim, width-1) yet contiguous along 'dim'
-            conv_state = self_kv_cache[0].transpose(-1, -2)
-            ssm_state = self_kv_cache[1]
+            conv_state = self.kv_cache[0].transpose(-1, -2)
+            ssm_state = self.kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
             has_initial_states_p = attn_metadata.has_initial_states_p
             prep_initial_states = attn_metadata.prep_initial_states
diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py
index afaa706929a2c..799e086dba044 100644
--- a/vllm/model_executor/layers/mamba/short_conv.py
+++ b/vllm/model_executor/layers/mamba/short_conv.py
@@ -118,8 +118,7 @@ class ShortConv(MambaBase, CustomOp):
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
             assert isinstance(attn_metadata, ShortConvAttentionMetadata)
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-            conv_state = self_kv_cache[0].transpose(-1, -2)
+            conv_state = self.kv_cache[0].transpose(-1, -2)
             state_indices_tensor = attn_metadata.state_indices_tensor
             has_initial_states_p = attn_metadata.has_initial_states_p
 
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 5b55b685dacfc..ebd6e6b5df497 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -471,7 +471,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
         self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
     ):
         super().__init__()
-        self.kv_cache = [torch.tensor([])]
+        self.kv_cache = torch.tensor([])
         self.head_dim = head_dim
         self.prefix = prefix
         self.cache_config = cache_config
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index b35a8c6b66f26..5c86fdbc9ddc6 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -258,10 +258,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
             assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             # conv_state = (..., dim, width-1) yet contiguous along 'dim'
-            conv_state = self_kv_cache[0].transpose(-1, -2)
-            ssm_state = self_kv_cache[1]
+            conv_state = self.kv_cache[0].transpose(-1, -2)
+            ssm_state = self.kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
             has_initial_states_p = attn_metadata.has_initial_states_p
             prep_initial_states = attn_metadata.prep_initial_states
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index f891a4961dd70..eabee4ab3d9c3 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -458,9 +458,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-        conv_state = self_kv_cache[0].transpose(-1, -2)
-        ssm_state = self_kv_cache[1]
+        conv_state = self.kv_cache[0].transpose(-1, -2)
+        ssm_state = self.kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
 
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index a35dda2d77345..529702e06f145 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -2013,55 +2013,6 @@ def get_mp_context():
     return multiprocessing.get_context(mp_method)
 
 
-def bind_kv_cache(
-    ctx: dict[str, Any],
-    kv_cache: list[list[torch.Tensor]],  # [virtual_engine][layer_index]
-    shared_kv_cache_layers: dict[str, str] | None = None,
-) -> None:
-    # Bind the kv_cache tensor to Attention modules, similar to
-    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
-    # Special things handled here:
-    # 1. Some models have non-attention layers, e.g., Jamba
-    # 2. Pipeline parallelism, each rank only has a subset of layers
-    # 3. Encoder attention has no kv cache
-    # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
-    #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
-    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
-    #    tensor
-    # 5. Some models have attention layers that share kv cache with previous
-    #    layers, this is specified through shared_kv_cache_layers
-    if shared_kv_cache_layers is None:
-        shared_kv_cache_layers = {}
-    from vllm.attention import AttentionType
-    from vllm.model_executor.models.utils import extract_layer_index
-
-    layer_need_kv_cache = [
-        layer_name
-        for layer_name in ctx
-        if (
-            hasattr(ctx[layer_name], "attn_type")
-            and ctx[layer_name].attn_type
-            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)
-        )
-        and ctx[layer_name].kv_sharing_target_layer_name is None
-    ]
-    layer_index_sorted = sorted(
-        set(extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)
-    )
-    for layer_name in layer_need_kv_cache:
-        kv_cache_idx = layer_index_sorted.index(extract_layer_index(layer_name))
-        forward_ctx = ctx[layer_name]
-        assert len(forward_ctx.kv_cache) == len(kv_cache)
-        for ve, ve_kv_cache in enumerate(kv_cache):
-            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
-    if shared_kv_cache_layers is not None:
-        for layer_name, target_layer_name in shared_kv_cache_layers.items():
-            assert extract_layer_index(target_layer_name) < extract_layer_index(
-                layer_name
-            ), "v0 doesn't support interleaving kv sharing"
-            ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
-
-
 def run_method(
     obj: Any,
     method: str | bytes | Callable,
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index f384ede066210..b7dce67687464 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -318,8 +318,7 @@ def bind_kv_cache(
 
     # Bind kv_caches to forward context
     for layer_name, kv_cache in kv_caches.items():
-        # NOTE: Use list because of v0 PP virtual engine.
-        forward_context[layer_name].kv_cache = [kv_cache]
+        forward_context[layer_name].kv_cache = kv_cache
 
 
 def is_residual_scattered_for_sp(
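
Reviewer note (not part of the patch): the net effect of this change is that each layer now holds a single kv_cache tensor instead of a per-virtual-engine list, and ForwardContext no longer carries virtual_engine. Below is a minimal, self-contained sketch of the post-patch binding/lookup flow; FakeLayer and the module-level names are stand-ins for illustration only, not the real vLLM API.

```python
# Minimal sketch of the post-patch kv_cache flow (assumes only torch).
import torch


class FakeLayer:
    """Stand-in for an attention/mamba layer registered in the forward context."""

    def __init__(self) -> None:
        # Placeholder tensor at init time; replaced later by bind_kv_cache.
        self.kv_cache = torch.tensor([])

    def forward(self) -> torch.Tensor:
        # Post-patch: read self.kv_cache directly, with no
        # self.kv_cache[forward_context.virtual_engine] indexing.
        return self.kv_cache


def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, FakeLayer],
) -> None:
    # Mirrors vllm/v1/worker/utils.py after the patch: bind each layer's
    # kv_cache tensor directly (previously wrapped in a one-element list
    # because of the v0 pipeline-parallel virtual engine).
    for layer_name, kv_cache in kv_caches.items():
        forward_context[layer_name].kv_cache = kv_cache


ctx = {"layers.0.self_attn": FakeLayer(), "layers.1.self_attn": FakeLayer()}
caches = {name: torch.zeros(1) for name in ctx}
bind_kv_cache(caches, ctx)
assert ctx["layers.0.self_attn"].kv_cache is caches["layers.0.self_attn"]
assert ctx["layers.1.self_attn"].forward() is caches["layers.1.self_attn"]
```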