[P/D][NixlConnector] Enable FlashInfer backend (#19090)

Nicolò Lucchesi 2025-06-05 19:10:15 +02:00 committed by GitHub
parent 85e2b7bb13
commit 9ef9173cfa
2 changed files with 51 additions and 15 deletions

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

@@ -15,6 +15,7 @@ import torch
import zmq
from vllm import envs
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
@@ -22,6 +23,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.logger import init_logger
from vllm.platforms import _Backend
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
@@ -57,6 +59,7 @@ class NixlAgentMetadata(
num_blocks: int
tp_size: int
block_len: int
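# Name of the attention backend in use; exchanged during the handshake so
# that P and D workers can check they agree on the KV cache layout.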
attn_backend_name: str
@dataclass
@@ -384,11 +387,25 @@ class NixlConnectorWorker:
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
# List of block window sizes for each layer for local attention
self.block_window_per_layer: list[Optional[int]] = []
self.use_mla = self.model_config.use_mla
backend = get_attn_backend(self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.use_mla)
self.backend_name = backend.get_name()
attn_backend = backend_name_to_enum(self.backend_name)
self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
logger.debug("Detected attention backend %s", self.backend_name)
self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to
@@ -472,12 +489,16 @@ class NixlConnectorWorker:
kv_elem_size = first_kv_cache.element_size()
# TODO(tms): Find a more robust way to detect and handle MLA
self.use_mla = len(first_kv_cache.shape) == 3
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# only affects the strides. For MLA instead, no such layout is required
# and we resort to the standard one.
if self.use_mla:
use_mla = len(first_kv_cache.shape) == 3
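# Sanity check: the shape-based detection must agree with the model config
# flag resolved in __init__.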
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator. Enforce check
# once it goes live, as a single kv layout is expected for xfers.
if use_mla:
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
@@ -485,11 +506,16 @@
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
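# Illustrative shapes: the default [2, num_blocks, block_size, kv_heads,
# head_dim] becomes [num_blocks, 2, block_size, kv_heads, head_dim], so the
# K and V data of a given block sit next to each other in memory.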
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape
block_size, n_kv_heads, head_dim = block_shape[-3:]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
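# e.g. with illustrative values: fp16 KV (2 bytes), 8 KV heads and
# head_dim 128 give 2 * 8 * 128 = 2048 bytes per token slot.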
assert block_size == self.block_size
@@ -497,12 +523,10 @@
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.debug("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
logger.info(
"Registering KV_Caches: use_mla: %s, num_blocks: %s, "
"block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
self.num_blocks, block_shape, first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.kv_caches = kv_caches
kv_caches_base_addr = []
@@ -514,9 +538,12 @@ class NixlConnectorWorker:
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
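# Hence a single region of num_blocks * block_len bytes is registered per
# layer, with block_len already covering both K and V in the FlashInfer case.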
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
@@ -581,7 +608,8 @@ class NixlConnectorWorker:
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
tp_size=self.world_size,
block_len=self.block_len)
block_len=self.block_len,
attn_backend_name=self.backend_name)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
@@ -641,6 +669,10 @@ class NixlConnectorWorker:
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
else:
self._tp_size[engine_id] = nixl_agent_meta.tp_size
# We may eventually relax this and allow mismatched attention backends,
# after asserting equality in cache layout and closeness of outputs.
assert nixl_agent_meta.attn_backend_name == self.backend_name
self._remote_agents[engine_id][
remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
@@ -659,13 +691,16 @@ class NixlConnectorWorker:
else:
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes * tp_ratio)
if self._use_flashinfer:
# Account for joint KV in FlashInfer.
remote_block_size //= 2
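# block_len covers both K and V here, so block_len // slot_size_bytes
# yields 2 * block_size; halving recovers the actual block size.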
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
assert self.block_size == remote_block_size, "Remote P worker with "
assert self.block_size == remote_block_size, "Remote P worker with " \
"different block size is not supported"
assert self.num_blocks >= nixl_agent_meta.num_blocks

vllm/platforms/interface.py

@@ -47,6 +47,7 @@ class _Backend(enum.Enum):
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_VLLM_V1 = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1
TRITON_MLA_VLLM_V1 = enum.auto()
FLASHMLA_VLLM_V1 = enum.auto()