[P/D][NixlConnector] Enable FlashInfer backend (#19090)
commit 9ef9173cfa
parent 85e2b7bb13
@@ -15,6 +15,7 @@ import torch
 import zmq

 from vllm import envs
+from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
@@ -22,6 +23,7 @@ from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
     get_tp_group)
 from vllm.logger import init_logger
+from vllm.platforms import _Backend
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
@@ -57,6 +59,7 @@ class NixlAgentMetadata(
     num_blocks: int
     tp_size: int
     block_len: int
+    attn_backend_name: str


 @dataclass
@@ -384,11 +387,25 @@ class NixlConnectorWorker:

         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config

         # TODO(mgoin): remove this once we have hybrid memory allocator
         # Optimization for models with local attention (Llama 4)
         # List of block window sizes for each layer for local attention
         self.block_window_per_layer: list[Optional[int]] = []
+        self.use_mla = self.model_config.use_mla
+
+        backend = get_attn_backend(self.model_config.get_head_size(),
+                                   self.model_config.dtype,
+                                   self.cache_config.cache_dtype,
+                                   self.block_size,
+                                   self.model_config.is_attention_free,
+                                   use_mla=self.use_mla)
+        self.backend_name = backend.get_name()
+        attn_backend = backend_name_to_enum(self.backend_name)
+        self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
+        logger.debug("Detected attention backend %s", self.backend_name)
+
         self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
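Note (not part of the commit): a minimal, self-contained sketch of the backend-detection idea in the hunk above. The real _Backend enum and backend_name_to_enum live inside vLLM, so stand-in definitions are used here; the member names mirror the diff but everything else is hypothetical.

import enum


class Backend(enum.Enum):
    # Stand-in for vllm.platforms._Backend, reduced to the members this sketch needs.
    FLASH_ATTN_VLLM_V1 = enum.auto()
    FLASHINFER_VLLM_V1 = enum.auto()


def backend_name_to_enum(name: str) -> Backend:
    # Stand-in for vllm.attention.selector.backend_name_to_enum.
    return Backend[name]


def is_flashinfer(backend_name: str) -> bool:
    # Mirrors the flag set above:
    # self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
    return backend_name_to_enum(backend_name) == Backend.FLASHINFER_VLLM_V1


assert is_flashinfer("FLASHINFER_VLLM_V1")
assert not is_flashinfer("FLASH_ATTN_VLLM_V1")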
@@ -472,12 +489,16 @@ class NixlConnectorWorker:
         kv_elem_size = first_kv_cache.element_size()

         # TODO(tms): Find a more robust way to detect and handle MLA
-        self.use_mla = len(first_kv_cache.shape) == 3
         # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
         # KV memory layout is HND, as opposed to the default NHD. Note that it
         # will only affects the strides. For MLA instead, we make require no
         # such thing and resort to the standard layout.
-        if self.use_mla:
+        use_mla = len(first_kv_cache.shape) == 3
+        assert use_mla == self.use_mla
+
+        # TODO (NickLucche) not compatible with hybrid allocator. Enforce check
+        # once it goes live, as a single kv layout is expected for xfers.
+        if use_mla:
             # MLA case.
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
@@ -485,11 +506,16 @@
             block_size, kv_latent_dim = block_shape
             self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
-            # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
-            self.num_blocks = first_kv_cache.shape[1]
-            block_rank = 3  # [block_size, kv_heads, head_dim]
+            # [2 (k and v), num_blocks, ...]
+            if self._use_flashinfer:
+                # FlashInfer swaps 2<->num_blocks dimensions.
+                self.num_blocks = first_kv_cache.shape[0]
+                block_rank = 4  # [2, block_size, kv_heads, head_dim]
+            else:
+                self.num_blocks = first_kv_cache.shape[1]
+                block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            block_size, n_kv_heads, head_dim = block_shape
+            block_size, n_kv_heads, head_dim = block_shape[-3:]
             # head size in bytes.
             self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
             assert block_size == self.block_size
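Note (not part of the commit): a small sketch, in plain PyTorch, of how num_blocks, block_shape, the per-slot size and the per-block length fall out of the two non-MLA KV-cache layouts handled in the hunk above. The tensor shapes below are made up for illustration.

import math
import torch

block_size, n_kv_heads, head_dim, num_blocks = 16, 8, 128, 4

# Default layout: [2 (K and V), num_blocks, block_size, kv_heads, head_dim]
fa_cache = torch.zeros(2, num_blocks, block_size, n_kv_heads, head_dim,
                       dtype=torch.float16)
# FlashInfer layout: num_blocks is the first dim and K/V share one tensor:
# [num_blocks, 2, block_size, kv_heads, head_dim]
fi_cache = torch.zeros(num_blocks, 2, block_size, n_kv_heads, head_dim,
                       dtype=torch.float16)


def describe(cache: torch.Tensor, use_flashinfer: bool):
    kv_elem_size = cache.element_size()
    if use_flashinfer:
        nb = cache.shape[0]
        block_rank = 4  # [2, block_size, kv_heads, head_dim]
    else:
        nb = cache.shape[1]
        block_rank = 3  # [block_size, kv_heads, head_dim]
    block_shape = cache.shape[-block_rank:]
    bs, heads, hd = block_shape[-3:]
    slot_size_bytes = kv_elem_size * heads * hd        # bytes per token slot
    block_len = kv_elem_size * math.prod(block_shape)  # bytes per block region
    return nb, tuple(block_shape), slot_size_bytes, block_len


print(describe(fa_cache, use_flashinfer=False))  # block_len covers K or V only
print(describe(fi_cache, use_flashinfer=True))   # block_len covers K and V jointly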
@@ -497,12 +523,10 @@
         # hybrid attn, etc
         # block size in bytes
         self.block_len = kv_elem_size * math.prod(block_shape)
-
-        logger.debug("Registering KV_Caches. use_mla: %s, shape %s",
-                     self.use_mla, first_kv_cache.shape)
-        logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
-                     block_shape)
-        logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
+        logger.info(
+            "Registering KV_Caches: use_mla: %s, num_blocks: %s, "
+            "block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
+            self.num_blocks, block_shape, first_kv_cache.shape)
         self.dst_num_blocks[self.engine_id] = self.num_blocks
         self.kv_caches = kv_caches
         kv_caches_base_addr = []
@@ -514,9 +538,12 @@
         # are non-contiguous (it's not locally guaranteed that they will be)
         # Disadvantage is that the encoded NixlAgentMetadata is now larger
         # (roughly 8KB vs 5KB).
+        # Conversely for FlashInfer, K and V are transferred in the same tensor
+        # to better exploit the memory layout (ie num_blocks is the first dim).
         for cache_or_caches in kv_caches.values():
             # Normalize to always be a list of caches
-            cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
+            cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
+                else cache_or_caches
             for cache in cache_list:
                 base_addr = cache.data_ptr()
                 region_len = self.num_blocks * self.block_len
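Note (not part of the commit): a hedged sketch of the normalization in the hunk above, i.e. how each layer's cache becomes a list of (base_address, length) regions to register. With MLA or FlashInfer a layer contributes a single joint region; otherwise separate K and V tensors each contribute one. Names and sizes are hypothetical.

import torch


def regions_for_layer(cache_or_caches, joint_kv: bool):
    # joint_kv covers the MLA and FlashInfer cases (one tensor per layer);
    # otherwise cache_or_caches is an iterable of separate K and V tensors.
    cache_list = [cache_or_caches] if joint_kv else cache_or_caches
    return [(cache.data_ptr(), cache.numel() * cache.element_size())
            for cache in cache_list]


num_blocks = 4
fi_layer = torch.zeros(num_blocks, 2, 16, 8, 128, dtype=torch.float16)
fa_layer = torch.zeros(2, num_blocks, 16, 8, 128, dtype=torch.float16)

assert len(regions_for_layer(fi_layer, joint_kv=True)) == 1       # joint K/V region
assert len(regions_for_layer(list(fa_layer), joint_kv=False)) == 2  # K and V regions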
@@ -581,7 +608,8 @@ class NixlConnectorWorker:
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
             tp_size=self.world_size,
-            block_len=self.block_len)
+            block_len=self.block_len,
+            attn_backend_name=self.backend_name)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -641,6 +669,10 @@ class NixlConnectorWorker:
             assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
         else:
             self._tp_size[engine_id] = nixl_agent_meta.tp_size
+        # We may eventually enable this after asserting equality in cache
+        # layout and close outputs.
+        assert nixl_agent_meta.attn_backend_name == self.backend_name
+
         self._remote_agents[engine_id][
             remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
                 nixl_agent_meta.agent_metadata)
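Note (not part of the commit): the handshake metadata now carries the attention backend name, and the receiving worker asserts that both sides detected the same backend. A minimal sketch of that check, using a hypothetical stand-in for the metadata:

from dataclasses import dataclass


@dataclass
class HandshakeMeta:
    # Stand-in for the single NixlAgentMetadata field used in this sketch.
    attn_backend_name: str


def check_backend(local_backend_name: str, meta: HandshakeMeta) -> None:
    # Mirrors: assert nixl_agent_meta.attn_backend_name == self.backend_name
    assert meta.attn_backend_name == local_backend_name, (
        f"P/D attention backend mismatch: {meta.attn_backend_name} vs "
        f"{local_backend_name}")


check_backend("FLASHINFER_VLLM_V1", HandshakeMeta("FLASHINFER_VLLM_V1"))  # passes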
@@ -659,13 +691,16 @@
         else:
             remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes * tp_ratio)
+            if self._use_flashinfer:
+                # Account for joint KV in FlashInfer.
+                remote_block_size //= 2

         assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
             "Remote P worker KV layer cache must be of shape [2, N, "
             "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
         )

-        assert self.block_size == remote_block_size, "Remote P worker with "
+        assert self.block_size == remote_block_size, "Remote P worker with " \
             "different block size is not supported"

         assert self.num_blocks >= nixl_agent_meta.num_blocks
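Note (not part of the commit): a worked example, with made-up sizes, of why remote_block_size is halved for FlashInfer. Because K and V share one tensor, the remote block_len covers both, so dividing by slot_size * tp_ratio alone would report a block size twice as large as the real one.

kv_elem_size = 2                  # fp16 element size in bytes
n_kv_heads, head_dim = 8, 128
block_size = 16
tp_ratio = 1                      # homogeneous TP in this example

slot_size_bytes = kv_elem_size * n_kv_heads * head_dim       # 2048
# FlashInfer block covers K and V: [2, block_size, kv_heads, head_dim]
remote_block_len = 2 * block_size * slot_size_bytes          # 65536

remote_block_size = remote_block_len // (slot_size_bytes * tp_ratio)  # 32, too big
remote_block_size //= 2           # account for joint KV -> 16
assert remote_block_size == block_size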
@@ -47,6 +47,7 @@ class _Backend(enum.Enum):
     ROCM_AITER_MLA_VLLM_V1 = enum.auto()
     TORCH_SDPA = enum.auto()
     FLASHINFER = enum.auto()
+    FLASHINFER_VLLM_V1 = enum.auto()
     TRITON_MLA = enum.auto()  # Supported by V1
     TRITON_MLA_VLLM_V1 = enum.auto()
     FLASHMLA_VLLM_V1 = enum.auto()