From 980385f8c1839f98e04928cb07c643d20b140cf7 Mon Sep 17 00:00:00 2001 From: Mathis Felardos Date: Sat, 8 Mar 2025 07:39:31 +0100 Subject: [PATCH] [Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and fix the reshape of the KVCache (#14369) Signed-off-by: Mathis Felardos --- .../kv_connector/simple_connector.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 8e2fbf36b4de4..7315a6f45f7d2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -2,7 +2,7 @@ """ Simple KV Cache Connector for Distributed Machine Learning Inference -The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or MooncakePipe. @@ -159,6 +159,7 @@ class SimpleConnector(KVConnectorBase): input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens start_layer = model_executable.model.start_layer end_layer = model_executable.model.end_layer @@ -166,7 +167,8 @@ class SimpleConnector(KVConnectorBase): num_heads = int(model_config.num_key_value_heads / self.tp_size) hidden_size = model_config.hidden_size num_attention_heads = model_config.num_attention_heads - head_size = int(hidden_size / num_attention_heads) + head_size = getattr(model_config, "head_dim", + int(hidden_size // num_attention_heads)) # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance @@ -174,6 +176,15 @@ class SimpleConnector(KVConnectorBase): for idx, slen in enumerate(seq_lens): start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You have some decode requests while using " + "SimpleConnector. Their KVCache won't be sent.") + break + current_tokens = input_tokens_tensor[start_pos:end_pos] keys, values = [], [] @@ -236,7 +247,7 @@ class SimpleConnector(KVConnectorBase): # - input_tokens[num_prefill_tokens:] contains decode tokens. logger.warning("You should set --enable_chunked_prefill=False " "and --max_num_batched_tokens " - "should be equal to max_seq_len_to_capture") + "should be equal to --max_seq_len_to_capture") bypass_model_exec = False assert start_pos == num_prefill_tokens break