diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
index 25675d70fe22..2485c57d86ec 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@@ -30,27 +30,19 @@ logger = init_logger(__name__)
 class ReqMeta:
     # Request Id
     request_id: str
-    # Request tokens
-    token_ids: torch.Tensor
-    # Slot mappings, should have the same length as token_ids
-    slot_mapping: torch.Tensor
+    # Request block ids
+    block_ids: torch.Tensor
+    # Request num tokens
+    num_tokens: int
 
     @staticmethod
     def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                   block_size: int) -> "ReqMeta":
-        valid_num_tokens = len(token_ids)
-        token_ids_tensor = torch.tensor(token_ids)
         block_ids_tensor = torch.tensor(block_ids)
-        num_blocks = block_ids_tensor.shape[0]
-        block_offsets = torch.arange(0, block_size)
-        slot_mapping = block_offsets.reshape((1, block_size)) + \
-            block_ids_tensor.reshape((num_blocks, 1)) * block_size
-        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
-
         return ReqMeta(
             request_id=request_id,
-            token_ids=token_ids_tensor,
-            slot_mapping=slot_mapping,
+            block_ids=block_ids_tensor,
+            num_tokens=len(token_ids),
         )
 
@@ -123,63 +115,58 @@ class P2pNcclConnector(KVConnectorBase_V1):
             return
 
         def inject_kv_into_layer(
-            dst_kv_cache_layer: torch.Tensor,
-            src_kv_cache: torch.Tensor,
-            slot_mapping: torch.Tensor,
+            layer: torch.Tensor,
+            kv_cache: torch.Tensor,
+            block_ids: torch.Tensor,
             request_id: str,
         ) -> None:
-            """Inject the KV cache into the layer.
+            """
+            Inject KV cache data into a given attention layer tensor.
+
+            This function updates `layer` in-place with values from
+            `kv_cache`, handling different backend layouts:
+            - MLA (Multi-head Latent Attention) or FlashInfer: KV tensors
+              are indexed along the first dimension.
+            - FlashAttention: KV tensors are indexed along the second
+              dimension.
+
+            If the number of provided block IDs does not match the number
+            of KV blocks, only the overlapping portion is updated, and a
+            warning is logged.
 
             Args:
-                dst_kv_cache_layer (torch.Tensor): the destination KV cache
-                    layer. In shape [2, num_pages, page_size, xxx] if not
-                    using MLA, [num_pages, page_size, xxx] otherwise.
-                src_kv_cache (torch.Tensor): the source KV cache. In shape
-                    [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
-                    otherwise.
-                slot_mapping (torch.Tensor): the slot mapping. In shape
-                    [num_tokens].
-                request_id (str): request id for log
+                layer (torch.Tensor): The attention layer KV tensor to update.
+                kv_cache (torch.Tensor): The KV cache tensor to inject.
+                block_ids (torch.Tensor): Indices of the blocks to update.
+                request_id (str): Request identifier used for logging.
+
+            Returns:
+                None. The function modifies `layer` in-place.
             """
-            dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages = dst_kv_cache_layer_shape[0]
-                page_size = dst_kv_cache_layer_shape[1]
-                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    num_pages * page_size, -1)
-                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
-                                              0)
-                num_token = src_kv_cache.shape[0]
-                if len(slot_mapping) == num_token:
-                    dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
+            if (isinstance(attn_metadata, MLACommonMetadata)
+                    or layer.shape[1] == 2):  # MLA or FlashInfer
+                num_block = kv_cache.shape[0]
+                self.check_tensors_except_dim(layer, kv_cache, 0)
+                if len(block_ids) == num_block:
+                    layer[block_ids, ...] = kv_cache
                 else:
-                    dst_kv_cache_layer[slot_mapping[:num_token],
-                                       ...] = src_kv_cache
+                    layer[block_ids[:num_block], ...] = kv_cache
                     logger.warning(
-                        "🚧src_kv_cache does not match, num_slot:%d, "
-                        "num_token:%d, request_id:%s", len(slot_mapping),
-                        num_token, request_id)
+                        "🚧kv_cache does not match, len(block_ids):%d, "
+                        "num_block:%d, request_id:%s", len(block_ids),
+                        num_block, request_id)
 
-                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
-            else:
-                num_pages = dst_kv_cache_layer_shape[1]
-                page_size = dst_kv_cache_layer_shape[2]
-                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    2, num_pages * page_size, -1)
-                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
-                                              1)
-                num_token = src_kv_cache.shape[1]
-                if len(slot_mapping) == num_token:
-                    dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
+            elif layer.shape[0] == 2:  # FlashAttention
+                num_block = kv_cache.shape[1]
+                self.check_tensors_except_dim(layer, kv_cache, 1)
+                if len(block_ids) == num_block:
+                    layer[:, block_ids, ...] = kv_cache
                 else:
-                    dst_kv_cache_layer[:, slot_mapping[:num_token],
-                                       ...] = src_kv_cache
+                    layer[:, block_ids[:num_block], ...] = kv_cache
                     logger.warning(
-                        "🚧src_kv_cache does not match, num_slot:%d, "
-                        "num_token:%d, request_id:%s", len(slot_mapping),
-                        num_token, request_id)
-
-                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
+                        "🚧kv_cache does not match, len(block_ids):%d, "
+                        "num_block:%d, request_id:%s", len(block_ids),
+                        num_block, request_id)
 
         # Get the metadata
         metadata: KVConnectorMetadata = \
@@ -201,19 +188,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 if kv_cache is None:
                     continue
 
-                kv_cache_layer = kv_cache[ \
-                    forward_context.virtual_engine]
+                layer = kv_cache[forward_context.virtual_engine]
 
                 kv_cache = self.p2p_nccl_engine.recv_tensor(
                     request.request_id + "#" + layer_name)
 
                 if kv_cache is None:
-                    logger.warning("🚧src_kv_cache is None, %s",
-                                   request.request_id)
+                    logger.warning("🚧kv_cache is None, %s",
+                                   request.request_id)
                     continue
 
-                inject_kv_into_layer(kv_cache_layer, kv_cache,
-                                     request.slot_mapping, request.request_id)
+                inject_kv_into_layer(layer, kv_cache, request.block_ids,
+                                     request.request_id)
 
     def wait_for_layer_load(self, layer_name: str) -> None:
         """Blocking until the KV for a specific layer is loaded into vLLM's
@@ -247,20 +232,33 @@ class P2pNcclConnector(KVConnectorBase_V1):
 
         def extract_kv_from_layer(
             layer: torch.Tensor,
-            slot_mapping: torch.Tensor,
+            block_ids: torch.Tensor,
         ) -> torch.Tensor:
-            """Extract the KV cache from the layer.
-
-            Assume the shape of the layer is (2, num_pages, page_size, xxx)
-            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
             """
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages, page_size = layer.shape[0], layer.shape[1]
-                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
-                                                                ...]
-            num_pages, page_size = layer.shape[1], layer.shape[2]
-            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
-                                                               ...]
+            Extract KV cache slices from a given attention layer tensor.
+
+            This function handles multiple backend layouts:
+            - MLA (Multi-head Latent Attention) or FlashInfer: KV tensors
+              are indexed along the first dimension.
+            - FlashAttention: KV tensors are indexed along the second
+              dimension.
+
+            Args:
+                layer (torch.Tensor): The KV cache from the attention layer.
+                block_ids (torch.Tensor): Indices of blocks to extract.
+
+            Returns:
+                torch.Tensor: A tensor containing the extracted KV slices.
+                    Returns None if the layout is unsupported.
+            """
+            if (isinstance(attn_metadata, MLACommonMetadata)
+                    or layer.shape[1] == 2):  # MLA or FlashInfer
+                return layer[block_ids, ...]
+
+            if layer.shape[0] == 2:  # FlashAttention
+                return layer[:, block_ids, ...]
+
+            return None
 
         connector_metadata = self._get_connector_metadata()
         assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
@@ -269,7 +267,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
             ip, port = self.parse_request_id(request_id, True)
             remote_address = ip + ":" + str(port + self._rank)
 
-            kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
+            kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
             self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                              kv_cache, remote_address)
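
Reviewer note: the standalone sketch below contrasts the removed per-token
slot-mapping gather with the new block-granularity indexing, for a
FlashAttention-style layout. It is hypothetical and for illustration only:
the sizes (num_blocks=8, block_size=16, head_dim=4) and the 40-token request
are made up, and none of this code is part of the patch.

# Standalone sketch (not part of the patch). Assumes a FlashAttention-style
# KV layout [2 (K/V), num_blocks, block_size, head_dim]; sizes are made up.
import torch

num_blocks, block_size, head_dim = 8, 16, 4
layer = torch.randn(2, num_blocks, block_size, head_dim)

token_ids = list(range(40))          # 40 tokens -> 3 blocks, last one partial
block_ids = torch.tensor([5, 2, 7])  # blocks assigned to this request

# Old approach (removed by this patch): expand block ids into one slot per
# token, then gather per-token rows from a flattened view of the layer.
block_offsets = torch.arange(0, block_size)
slot_mapping = (block_offsets.reshape((1, block_size)) +
                block_ids.reshape((-1, 1)) * block_size)
slot_mapping = slot_mapping.flatten()[:len(token_ids)]
per_token = layer.reshape(2, num_blocks * block_size, -1)[:, slot_mapping, ...]
print(per_token.shape)   # torch.Size([2, 40, 4])

# New approach: index whole blocks along dim 1; no reshape and no per-token
# slot math, at the cost of transferring the full, partially used last block.
per_block = layer[:, block_ids, ...]
print(per_block.shape)   # torch.Size([2, 3, 16, 4])

# The two agree on the tokens that actually exist.
assert torch.equal(per_token,
                   per_block.reshape(2, -1, head_dim)[:, :len(token_ids)])

Because block-granularity transfer ships whole blocks even when the last one
is only partially filled, ReqMeta presumably keeps num_tokens alongside
block_ids so the receiving side can still tell how many of the transferred
slots hold real tokens.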