[V1][P/D]P2pNcclConnector supports flashinfer (#23536)

Signed-off-by: Abatom <abzhonghua@gmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Author: Zhonghua Deng
Date: 2025-08-27 06:56:16 +08:00 (committed by GitHub)
parent 6421b66bf4
commit c3b0fd1ee6

@@ -30,27 +30,19 @@ logger = init_logger(__name__)
 class ReqMeta:
     # Request Id
     request_id: str
-    # Request tokens
-    token_ids: torch.Tensor
-    # Slot mappings, should have the same length as token_ids
-    slot_mapping: torch.Tensor
+    # Request block ids
+    block_ids: torch.Tensor
+    # Request num tokens
+    num_tokens: int
 
     @staticmethod
     def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                   block_size: int) -> "ReqMeta":
-        valid_num_tokens = len(token_ids)
-        token_ids_tensor = torch.tensor(token_ids)
         block_ids_tensor = torch.tensor(block_ids)
-        num_blocks = block_ids_tensor.shape[0]
-        block_offsets = torch.arange(0, block_size)
-        slot_mapping = block_offsets.reshape((1, block_size)) + \
-                       block_ids_tensor.reshape((num_blocks, 1)) * block_size
-        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
         return ReqMeta(
             request_id=request_id,
-            token_ids=token_ids_tensor,
-            slot_mapping=slot_mapping,
+            block_ids=block_ids_tensor,
+            num_tokens=len(token_ids),
        )
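
To put the ReqMeta change in context: the removed code expanded block ids into a per-token slot_mapping, while the new metadata keeps only the block ids and the token count, since whole KV blocks are now transferred. A minimal standalone sketch, with illustrative values not taken from the patch:

    import torch

    # Hypothetical request: 6 valid tokens stored in KV cache blocks 3 and 7.
    block_ids = [3, 7]
    block_size = 4
    token_ids = list(range(6))

    # Old ReqMeta: one slot index per token, derived from the block ids.
    block_ids_tensor = torch.tensor(block_ids)
    block_offsets = torch.arange(0, block_size)
    slot_mapping = (block_offsets.reshape(1, block_size) +
                    block_ids_tensor.reshape(-1, 1) * block_size)
    slot_mapping = slot_mapping.flatten()[:len(token_ids)]
    print(slot_mapping)                      # tensor([12, 13, 14, 15, 28, 29])

    # New ReqMeta: block-level metadata only.
    print(block_ids_tensor, len(token_ids))  # tensor([3, 7]) 6
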
@@ -123,63 +115,58 @@ class P2pNcclConnector(KVConnectorBase_V1):
             return
 
         def inject_kv_into_layer(
-            dst_kv_cache_layer: torch.Tensor,
-            src_kv_cache: torch.Tensor,
-            slot_mapping: torch.Tensor,
+            layer: torch.Tensor,
+            kv_cache: torch.Tensor,
+            block_ids: torch.Tensor,
             request_id: str,
         ) -> None:
-            """Inject the KV cache into the layer.
+            """
+            Inject KV cache data into a given attention layer tensor.
+
+            This function updates `layer` in-place with values from `kv_cache`,
+            handling different backend layouts:
+
+            - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
+              indexed along the first dimension.
+            - FlashAttention: KV tensors are indexed along the second
+              dimension.
+
+            If the number of provided block IDs does not match the number of KV
+            blocks, only the overlapping portion is updated, and a warning is
+            logged.
 
             Args:
-                dst_kv_cache_layer (torch.Tensor): the destination KV cache
-                    layer. In shape [2, num_pages, page_size, xxx] if not
-                    using MLA, [num_pages, page_size, xxx] otherwise.
-                src_kv_cache (torch.Tensor): the source KV cache. In shape
-                    [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
-                    otherwise.
-                slot_mapping (torch.Tensor): the slot mapping. In shape
-                    [num_tokens].
-                request_id (str): request id for log
+                layer (torch.Tensor): The attention layer KV tensor to update.
+                kv_cache (torch.Tensor): The KV cache tensor to inject.
+                block_ids (torch.Tensor): Indices of the blocks to update.
+                request_id (str): Request identifier used for logging.
+
+            Returns:
+                None. The function modifies `layer` in-place.
             """
-            dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages = dst_kv_cache_layer_shape[0]
-                page_size = dst_kv_cache_layer_shape[1]
-                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    num_pages * page_size, -1)
-                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
-                                              0)
-                num_token = src_kv_cache.shape[0]
-                if len(slot_mapping) == num_token:
-                    dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
+            if (isinstance(attn_metadata, MLACommonMetadata)
+                    or layer.shape[1] == 2):  # MLA or FlashInfer
+                num_block = kv_cache.shape[0]
+                self.check_tensors_except_dim(layer, kv_cache, 0)
+                if len(block_ids) == num_block:
+                    layer[block_ids, ...] = kv_cache
                 else:
-                    dst_kv_cache_layer[slot_mapping[:num_token],
-                                       ...] = src_kv_cache
+                    layer[block_ids[:num_block], ...] = kv_cache
                     logger.warning(
-                        "🚧src_kv_cache does not match, num_slot:%d, "
-                        "num_token:%d, request_id:%s", len(slot_mapping),
-                        num_token, request_id)
-                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
-            else:
-                num_pages = dst_kv_cache_layer_shape[1]
-                page_size = dst_kv_cache_layer_shape[2]
-                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    2, num_pages * page_size, -1)
-                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
-                                              1)
-                num_token = src_kv_cache.shape[1]
-                if len(slot_mapping) == num_token:
-                    dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
+                        "🚧kv_cache does not match, block_ids:%d, "
+                        "num_block:%d, request_id:%s", len(block_ids),
+                        num_block, request_id)
+            elif layer.shape[0] == 2:  # FlashAttention
+                num_block = kv_cache.shape[1]
+                self.check_tensors_except_dim(layer, kv_cache, 1)
+                if len(block_ids) == num_block:
+                    layer[:, block_ids, ...] = kv_cache
                 else:
-                    dst_kv_cache_layer[:, slot_mapping[:num_token],
-                                       ...] = src_kv_cache
+                    layer[:, block_ids[:num_block], ...] = kv_cache
                     logger.warning(
-                        "🚧src_kv_cache does not match, num_slot:%d, "
-                        "num_token:%d, request_id:%s", len(slot_mapping),
-                        num_token, request_id)
-                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
+                        "🚧kv_cache does not match, block_ids:%d, "
+                        "num_block:%d, request_id:%s", len(block_ids),
+                        num_block, request_id)
 
         # Get the metadata
         metadata: KVConnectorMetadata = \
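
For readers less familiar with the two cache layouts, here is a torch-only sketch of the block-indexed in-place update performed above. Shapes are illustrative only; real vLLM KV caches also carry head dimensions:

    import torch

    num_blocks, block_size, head_dim = 16, 4, 8
    block_ids = torch.tensor([3, 7])

    # FlashInfer-style layout: K/V paired in dim 1, blocks indexed along dim 0
    # (this is what the `layer.shape[1] == 2` check detects; MLA caches are
    # also indexed along dim 0).
    layer_fi = torch.zeros(num_blocks, 2, block_size, head_dim)
    kv_fi = torch.randn(len(block_ids), 2, block_size, head_dim)
    layer_fi[block_ids, ...] = kv_fi        # whole blocks copied in-place

    # FlashAttention-style layout: K/V paired in dim 0, blocks indexed along dim 1.
    layer_fa = torch.zeros(2, num_blocks, block_size, head_dim)
    kv_fa = torch.randn(2, len(block_ids), block_size, head_dim)
    layer_fa[:, block_ids, ...] = kv_fa

    # Mismatch branch: fewer received blocks than block ids, so only the
    # overlapping prefix is written (the connector logs a warning in this case).
    kv_short = torch.randn(1, 2, block_size, head_dim)
    layer_fi[block_ids[:kv_short.shape[0]], ...] = kv_short
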
@@ -201,19 +188,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
             if kv_cache is None:
                 continue
 
-            kv_cache_layer = kv_cache[ \
-                forward_context.virtual_engine]
+            layer = kv_cache[forward_context.virtual_engine]
 
             kv_cache = self.p2p_nccl_engine.recv_tensor(
                 request.request_id + "#" + layer_name)
 
             if kv_cache is None:
-                logger.warning("🚧src_kv_cache is None, %s",
-                               request.request_id)
+                logger.warning("🚧kv_cache is None, %s", request.request_id)
                 continue
 
-            inject_kv_into_layer(kv_cache_layer, kv_cache,
-                                 request.slot_mapping, request.request_id)
+            inject_kv_into_layer(layer, kv_cache, request.block_ids,
+                                 request.request_id)
 
     def wait_for_layer_load(self, layer_name: str) -> None:
         """Blocking until the KV for a specific layer is loaded into vLLM's
@@ -247,20 +232,33 @@ class P2pNcclConnector(KVConnectorBase_V1):
         def extract_kv_from_layer(
             layer: torch.Tensor,
-            slot_mapping: torch.Tensor,
+            block_ids: torch.Tensor,
         ) -> torch.Tensor:
-            """Extract the KV cache from the layer.
-
-            Assume the shape of the layer is (2, num_pages, page_size, xxx)
-            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
-            """
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages, page_size = layer.shape[0], layer.shape[1]
-                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
-                                                                ...]
-            num_pages, page_size = layer.shape[1], layer.shape[2]
-            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
-                                                               ...]
+            """
+            Extract KV cache slices from a given attention layer tensor.
+
+            This function handles multiple backend layouts:
+
+            - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
+              indexed along the first dimension.
+            - FlashAttention: KV tensors are indexed along the second
+              dimension.
+
+            Args:
+                layer (torch.Tensor): The KV cache from the attention layer.
+                block_ids (torch.Tensor): Indices of blocks to extract.
+
+            Returns:
+                torch.Tensor: A tensor containing the extracted KV slices.
+                Returns None if the layout is unsupported.
+            """
+            if (isinstance(attn_metadata, MLACommonMetadata)
+                    or layer.shape[1] == 2):  # MLA or FlashInfer
+                return layer[block_ids, ...]
+            if layer.shape[0] == 2:  # FlashAttention
+                return layer[:, block_ids, ...]
+            return None
 
         connector_metadata = self._get_connector_metadata()
         assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
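
A hedged, torch-only sketch of the slices that the new extract_kv_from_layer returns on the send side (illustrative shapes again; any other layout yields None, mirroring the fallback):

    import torch

    num_blocks, block_size, head_dim = 16, 4, 8
    block_ids = torch.tensor([3, 7])

    # FlashInfer-style cache [num_blocks, 2, block_size, ...]: slice dim 0.
    layer_fi = torch.randn(num_blocks, 2, block_size, head_dim)
    kv_fi = layer_fi[block_ids, ...]        # shape [2, 2, 4, 8]

    # FlashAttention-style cache [2, num_blocks, block_size, ...]: slice dim 1.
    layer_fa = torch.randn(2, num_blocks, block_size, head_dim)
    kv_fa = layer_fa[:, block_ids, ...]     # shape [2, 2, 4, 8]
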
@@ -269,7 +267,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
             ip, port = self.parse_request_id(request_id, True)
             remote_address = ip + ":" + str(port + self._rank)
 
-            kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
+            kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
             self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                              kv_cache, remote_address)
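
For orientation, the send and receive paths pair up through the same tensor key and a rank-offset address. A small sketch with hypothetical values (the ip and base port returned by parse_request_id are assumed to be encoded in the request id):

    # Hypothetical values, not from the patch.
    ip, base_port = "10.0.0.2", 20000
    rank = 1                                             # this worker's rank
    remote_address = ip + ":" + str(base_port + rank)    # "10.0.0.2:20001"

    request_id = "cmpl-123"
    layer_name = "model.layers.0.self_attn.attn"
    tensor_key = request_id + "#" + layer_name           # same key on both sides
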